blob: c6e853659c2721e8bd2e9f3d9887c43741336df1 [file] [log] [blame]
Georgios Pinitasd9769582017-08-03 10:19:40 +01001/*
Michalis Spyrouaea14c62019-01-03 11:10:25 +00002 * Copyright (c) 2017-2019 ARM Limited.
Georgios Pinitasd9769582017-08-03 10:19:40 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEReductionOperationKernel.h"
25
Georgios Pinitas8f5802f2019-02-22 11:08:32 +000026#include "arm_compute/core/CPP/Validate.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010027#include "arm_compute/core/Coordinates.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/IAccessWindow.h"
30#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/INEKernel.h"
32#include "arm_compute/core/NEON/NEMath.h"
John Richardson73d4aef2018-05-08 14:34:33 +010033#include "arm_compute/core/TensorInfo.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010034#include "arm_compute/core/Validate.h"
Michalis Spyrouaea14c62019-01-03 11:10:25 +000035#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010036
Michalis Spyroubcf8a962018-10-12 10:51:31 +010037#include "arm_compute/core/NEON/wrapper/wrapper.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010038#include <arm_neon.h>
39
Michalis Spyroubcf8a962018-10-12 10:51:31 +010040namespace arm_compute
41{
Georgios Pinitasd9769582017-08-03 10:19:40 +010042namespace
43{
Michalis Spyroub9626ab2019-05-13 17:41:01 +010044template <typename T>
45uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000046{
47 uint32x4_t mask{ 0 };
48 if(op == ReductionOperation::ARG_IDX_MIN)
49 {
50 mask = wrapper::vcgt(b, a);
51 }
52 else
53 {
54 mask = wrapper::vclt(b, a);
55 }
56
57 uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
58 if(axis != 0)
59 {
60 vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
61 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +000062 uint32x4x4_t res = { { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000063
64 return res;
65}
66
67uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x4_t c, ReductionOperation op, int axis)
68{
Georgios Pinitasd57891a2019-02-19 18:10:03 +000069 uint32x4x4_t mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000070 uint8x16_t mask_u8{ 0 };
71 if(op == ReductionOperation::ARG_IDX_MIN)
72 {
73 mask_u8 = wrapper::vcgt(b, a);
74 }
75 else
76 {
77 mask_u8 = wrapper::vclt(b, a);
78 }
Michalis Spyrou254a48a2019-01-14 17:27:39 +000079 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
80 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
81 mask.val[0] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
82 mask.val[1] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
83 mask.val[2] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
84 mask.val[3] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
85
Michalis Spyrouaea14c62019-01-03 11:10:25 +000086 uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
87 { idx + 4, idx + 5, idx + 6, idx + 7 },
88 { idx + 8, idx + 9, idx + 10, idx + 11 },
89 { idx + 12, idx + 13, idx + 14, idx + 15 }
90 }
91 };
92 if(axis != 0)
93 {
94 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
95 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
96 vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
97 vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
98 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +000099 uint32x4x4_t res =
100 {
101 {
102 vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
103 vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
104 vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
105 vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
106 }
107 };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000108
109 return res;
110}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100111
112// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
113float32x2_t calculate_min(float32x4_t in)
114{
115 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
116 return wrapper::vpmin(pmin, pmin);
117}
118
119// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
120float32x2_t calculate_max(float32x4_t in)
121{
122 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
123 return wrapper::vpmax(pmax, pmax);
124}
125// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
126int32x2_t calculate_min(int32x4_t in)
127{
128 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
129 return wrapper::vpmin(pmin, pmin);
130}
131
132// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
133int32x2_t calculate_max(int32x4_t in)
134{
135 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
136 return wrapper::vpmax(pmax, pmax);
137}
138
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100139template <typename T>
140uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000141{
142 uint32x4_t res_idx_mask{ 0 };
143 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
144
145 if(op == ReductionOperation::ARG_IDX_MIN)
146 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100147 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000148 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
149 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
150 }
151 else
152 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100153 auto pmax = calculate_max(vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100154 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000155 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
156 }
157
158 res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
159 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
160 pmin = wrapper::vpmin(pmin, pmin);
161 uint32_t res = wrapper::vgetlane(pmin, 0);
162
163 return (res - 0xFFFFFFFF);
164}
165
Usama Arifa4a08ad2019-05-20 12:38:33 +0100166// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
167inline uint8x8_t calculate_min(uint8x16_t in)
168{
169 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
170 pmin = wrapper::vpmin(pmin, pmin);
171 pmin = wrapper::vpmin(pmin, pmin);
172 return wrapper::vpmin(pmin, pmin);
173}
174// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
175inline uint8x8_t calculate_max(uint8x16_t in)
176{
177 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
178 pmax = wrapper::vpmax(pmax, pmax);
179 pmax = wrapper::vpmax(pmax, pmax);
180 return wrapper::vpmax(pmax, pmax);
181}
182
Usama Arif0a5a57a2019-05-23 14:20:33 +0100183template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000184uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_value, ReductionOperation op)
185{
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000186 uint32x4x4_t res_idx_mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000187 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
188 uint8x16_t mask_u8{ 0 };
189 if(op == ReductionOperation::ARG_IDX_MIN)
190 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100191 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000192 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
193 }
194 else
195 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100196 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000197 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
198 }
199
200 // Widen vectors
201 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
202 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
203 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
204 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
205 auto wide_u32_3 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
206 auto wide_u32_4 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
207 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
208 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
209 res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
210 res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
211 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
212 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
213 res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
214 res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);
215
216 uint32_t res = 0xFFFFFFFF;
217 int iter = 0;
218 do
219 {
220 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
221 pmin = wrapper::vpmin(pmin, pmin);
222 res = std::min(wrapper::vgetlane(pmin, 0), res);
223 iter++;
224 }
225 while(iter < 4);
226
227 return (res - 0xFFFFFFFF);
228}
229#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
230uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
231{
232 uint32x4x2_t mask{ 0 };
233 uint16x8_t mask_u16{ 0 };
234 if(op == ReductionOperation::ARG_IDX_MIN)
235 {
236 mask_u16 = wrapper::vcgt(b, a);
237 }
238 else
239 {
240 mask_u16 = wrapper::vclt(b, a);
241 }
242 mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
243 mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
244 uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
245 { idx + 4, idx + 5, idx + 6, idx + 7 }
246 }
247 };
248 if(axis != 0)
249 {
250 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
251 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
252 }
253 uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
254 wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
255 0, 0
256 };
257
258 return res;
259}
260
Usama Arifa4a08ad2019-05-20 12:38:33 +0100261// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
262inline float16x4_t calculate_min(float16x8_t in)
263{
264 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
265 pmin = wrapper::vpmin(pmin, pmin);
266 return wrapper::vpmin(pmin, pmin);
267}
268// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
269inline float16x4_t calculate_max(float16x8_t in)
270{
271 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
272 pmax = wrapper::vpmax(pmax, pmax);
273 return wrapper::vpmax(pmax, pmax);
274}
275
Usama Arif0a5a57a2019-05-23 14:20:33 +0100276template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000277uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
278{
279 uint32x4x2_t res_idx_mask{ 0 };
280 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
281 uint16x8_t mask_u16;
282 if(op == ReductionOperation::ARG_IDX_MIN)
283 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100284 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000285 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
286 }
287 else
288 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100289 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000290 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
291 }
292
293 // Widen vectors
294 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
295 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
296 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
297 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
298 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
299 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
300
301 uint32_t res = 0xFFFFFFFF;
302 int iter = 0;
303 do
304 {
305 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
306 pmin = wrapper::vpmin(pmin, pmin);
307 res = std::min(wrapper::vgetlane(pmin, 0), res);
308 iter++;
309 }
310 while(iter < 2);
311
312 return (res - 0xFFFFFFFF);
313}
314#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
315
/** Drives a per-slice reduction functor over an input/output tensor pair.
 *
 * Each reduceN() prepares input and output windows for reducing along axis N. It then walks
 * them slice by slice, building fresh Iterators and calling @p f once per slice. The arithmetic
 * itself lives in the functor.
 *
 * @tparam F Functor type. reduceX calls it as f(in, out, in_slice, out_slice, in_info, op);
 *           reduceY/Z/W call it as f(in, out, in_slice, out_slice, in_info, axis, op) with axis 1/2/3.
 */
template <class F>
class Reducer
{
public:
    /** Reduce along X (axis 0), invoking @p f once per 1D slice of @p window. */
    static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set out window
        // X is set to Dimension(0, 0, 0) in the output window.
        // NOTE(review): presumably so that each input row maps to a single output element — confirm.
        Window out_window(window);
        out_window.set(Window::DimX, Window::Dimension(0, 0, 0));

        // Get first input and output slices
        Window in_slice  = window.first_slice_window_1D();
        Window out_slice = out_window.first_slice_window_1D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), op);
        }
        // Stops as soon as either window has no further slice. The output window is only slid
        // when the input slide succeeded (&& short-circuits).
        while(window.slide_window_slice_1D(in_slice) && out_window.slide_window_slice_1D(out_slice));
    }
    /** Reduce along Y (axis 1), invoking @p f once per 2D slice with axis = 1. */
    static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in window
        // Input Y is pinned to a single step Dimension(0, 1, 1). The output Y becomes one step spanning
        // dimension(1) of the output.
        // NOTE(review): presumably the functor walks Y itself using the axis argument — confirm against RedOpYZW.
        Window in_window(window);
        Window out_window(window);

        in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
        out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));

        // Get first input and output slices
        Window in_slice  = in_window.first_slice_window_2D();
        Window out_slice = out_window.first_slice_window_2D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), 1, op);
        }
        while(in_window.slide_window_slice_2D(in_slice) && out_window.slide_window_slice_2D(out_slice));
    }
    /** Reduce along Z (axis 2), invoking @p f once per 3D slice with axis = 2. */
    static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in window
        // Same scheme as reduceY, applied to Z: the input Z is a single step, and the output Z is one
        // step spanning dimension(2) of the output.
        Window in_window(window);
        Window out_window(window);

        in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
        out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));

        // Get first input and output slices
        Window in_slice  = in_window.first_slice_window_3D();
        Window out_slice = out_window.first_slice_window_3D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), 2, op);
        }
        while(in_window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_3D(out_slice));
    }
    /** Reduce along W (axis 3), invoking @p f once per 4D slice with axis = 3. */
    static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in/out window
        // Both the input and output windows pin dimension 3 to a single step Dimension(0, 1, 1).
        Window in_window(window);
        Window out_window(window);

        in_window.set(3, Window::Dimension(0, 1, 1));
        out_window.set(3, Window::Dimension(0, 1, 1));

        // Get first input and output slices
        Window in_slice  = in_window.first_slice_window_4D();
        Window out_slice = out_window.first_slice_window_4D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), 3, op);
        }
        while(in_window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_4D(out_slice));
    }
};
406
/** Reduction functor along the X axis for non-quantized element types.
 *
 * Each window step loads one S-lane vector and folds it into a vector accumulator. For
 * ARG_IDX_MIN/ARG_IDX_MAX it also tracks per-lane indices. After the loop the lanes are folded
 * into one scalar, written to the output element.
 *
 * @tparam T Element type of the input tensor.
 * @tparam S Number of T lanes per loaded vector.
 */
template <typename T, int S>
struct RedOpX
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;

    /** Reduce one X row of @p input into a single element of @p output.
     *
     * @param[in,out] input     Iterator over the input row; advanced by execute_window_loop.
     * @param[in,out] output    Iterator positioned at the destination element.
     * @param[in]     in_slice  Window driving the loop over the row.
     * @param[in]     out_slice Unused.
     * @param[in]     in_info   Input tensor info; dimension(0) is the MEAN_SUM divisor.
     * @param[in]     op        Reduction to apply; unsupported values hit ARM_COMPUTE_ERROR.
     */
    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
    {
        ARM_COMPUTE_UNUSED(out_slice);
        // Seed the accumulator:
        // - 0 for the sum-style ops;
        // - 1 for PROD;
        // - the first input element for min/max-style ops, so that every lane starts from a real value.
        auto init_res_value = static_cast<T>(0.f);
        switch(op)
        {
            case ReductionOperation::ARG_IDX_MAX:
            case ReductionOperation::ARG_IDX_MIN:
            case ReductionOperation::MIN:
            case ReductionOperation::MAX:
            {
                init_res_value = *reinterpret_cast<T *>(input.ptr());
                break;
            }
            case ReductionOperation::PROD:
            {
                init_res_value = static_cast<T>(1.f);
                break;
            }
            default:
                break;
        }
        auto         vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
        // Per-lane indices for ARG_IDX_*; all lanes start at index 0, matching the seed value above.
        uint32x4x4_t vec_res_idx{ { 0 } };

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            const auto in_ptr       = reinterpret_cast<const T *>(input.ptr());
            const auto vec_elements = wrapper::vloadq(in_ptr);

            switch(op)
            {
                case ReductionOperation::SUM_SQUARE:
                    vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                    break;
                case ReductionOperation::MEAN_SUM:
                case ReductionOperation::SUM:
                    vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                    break;
                case ReductionOperation::PROD:
                    vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
                    break;
                case ReductionOperation::ARG_IDX_MIN:
                {
                    // calculate_index must see both the new and the previous extreme to know which lanes improved.
                    // NOTE(review): an explicit template argument restricts overload resolution to function
                    // templates. For a float16x8_t accumulator this binds to the primary calculate_index
                    // template unless the FP16 version is declared as an explicit specialization — confirm for T = float16_t.
                    auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                    vec_res_idx             = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value           = temp_vec_res_value;
                    break;
                }
                case ReductionOperation::ARG_IDX_MAX:
                {
                    auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                    vec_res_idx             = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value           = temp_vec_res_value;
                    break;
                }
                case ReductionOperation::MIN:
                {
                    vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                    break;
                }
                case ReductionOperation::MAX:
                {
                    vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                    break;
                }
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        },
        input);

        // Fold the accumulator lanes into one scalar and store it.
        switch(op)
        {
            case ReductionOperation::SUM:
            case ReductionOperation::SUM_SQUARE:
            case ReductionOperation::MEAN_SUM:
            {
                // The first vpadd halves S lanes to S/2, then S/4 further pairwise adds finish the fold
                // (one extra step for S = 4, two for S = 8).
                auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                for(int i = 0; i < S / 4; ++i)
                {
                    carry_res = wrapper::vpadd(carry_res, carry_res);
                }
                auto res = wrapper::vgetlane(carry_res, 0);

                if(op == ReductionOperation::MEAN_SUM)
                {
                    res /= in_info.dimension(0);
                }

                *(reinterpret_cast<T *>(output.ptr())) = res;
                break;
            }
            case ReductionOperation::PROD:
            {
                // Multiply the high and low halves, then multiply the S/2 remaining lanes together.
                auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                T    res       = 1;
                for(int i = 0; i < S / 2; ++i)
                {
                    res *= wrapper::vgetlane(carry_res, i);
                }
                *(reinterpret_cast<T *>(output.ptr())) = res;
                break;
            }
            case ReductionOperation::ARG_IDX_MIN:
            case ReductionOperation::ARG_IDX_MAX:
            {
                // Arg ops write a uint32_t index rather than a T value.
                auto res = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
                *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
                break;
            }
            case ReductionOperation::MIN:
            {
                *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0);
                break;
            }
            case ReductionOperation::MAX:
            {
                *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_max(vec_res_value), 0);
                break;
            }
            default:
                ARM_COMPUTE_ERROR("Not supported");
        }
    }
};
539
/** Reduction functor along the X axis for 8-bit unsigned (QASYMM8) inputs.
 *
 * Each window step loads one 16-byte vector. Accumulators by operation:
 * - SUM/MEAN_SUM: bytes are widened to uint32 and added into four uint32x4 accumulators.
 * - PROD: bytes are dequantized to float and multiplied into four float32x4 accumulators.
 * - MIN/MAX/ARG_IDX_*: a uint8x16 accumulator, plus per-lane indices for the arg ops.
 */
struct RedOpX_qasymm8
{
    /** Reduce one X row of @p input into a single element of @p output.
     *
     * @param[in,out] input     Iterator over the input row; advanced by execute_window_loop.
     * @param[in,out] output    Iterator positioned at the destination element.
     * @param[in]     in_slice  Window driving the loop over the row.
     * @param[in]     out_slice Unused.
     * @param[in]     in_info   Input tensor info: quantization scale/offset (PROD) and dimension(0) (MEAN_SUM).
     * @param[in]     op        Reduction to apply; unsupported values hit ARM_COMPUTE_ERROR inside the loop.
     */
    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
    {
        ARM_COMPUTE_UNUSED(out_slice);
        // SUM / MEAN_SUM accumulators: one per quarter of the 16 input bytes.
        auto vec_res_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_res_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_res_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_res_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));

        // PROD accumulators (dequantized float domain), seeded with 1.
        auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
        auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
        auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
        auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));

        // MIN / MAX / ARG_IDX_* accumulator; seeded below with the first input byte for those ops.
        uint8x16_t vec_res_value = { 0 };

        if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX)
        {
            vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{});
        }

        uint32x4x4_t vec_res_idx{ { 0 } };
        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            const auto vec_elements = wrapper::vloadq(input.ptr());
            switch(op)
            {
                case ReductionOperation::SUM:
                case ReductionOperation::MEAN_SUM:
                {
                    // Zero-extend u8 -> u16 -> u32 so that per-lane sums do not wrap at 255.
                    const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
                    const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));

                    const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
                    const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
                    const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
                    const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));

                    vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
                    vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
                    vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
                    vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
                    break;
                }
                case ReductionOperation::PROD:
                {
                    // NOTE(review): offset/scale do not change across iterations; they could be hoisted out of the loop.
                    const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
                    const auto scale32x4f_4  = vdupq_n_f32(in_info.quantization_info().scale);

                    const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
                    const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));

                    const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
                    const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
                    const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
                    const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));

                    auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
                    auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
                    auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
                    auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);

                    //de-quantize vec_elements
                    // value = (q - offset) * scale
                    temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
                    temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
                    temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
                    temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);

                    vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
                    vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
                    vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
                    vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
                    break;
                }
                case ReductionOperation::ARG_IDX_MIN:
                {
                    // Unqualified call: overload resolution picks the non-template uint8x16_t calculate_index.
                    auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                    vec_res_idx             = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value           = temp_vec_res_value;
                    break;
                }
                case ReductionOperation::ARG_IDX_MAX:
                {
                    auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                    vec_res_idx             = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value           = temp_vec_res_value;
                    break;
                }
                case ReductionOperation::MIN:
                {
                    vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                    break;
                }
                case ReductionOperation::MAX:
                {
                    vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                    break;
                }
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        },
        input);

        switch(op)
        {
            case ReductionOperation::ARG_IDX_MIN:
            case ReductionOperation::ARG_IDX_MAX:
            {
                // Arg ops write a uint32_t index rather than a uint8_t value.
                auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
                *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
                break;
            }
            case ReductionOperation::MIN:
            {
                *(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
                break;
            }
            case ReductionOperation::MAX:
            {
                *(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
                break;
            }
            case ReductionOperation::PROD:
            {
                auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
                carry_res      = wrapper::vmul(carry_res, vec_res_value3_f);
                carry_res      = wrapper::vmul(carry_res, vec_res_value4_f);

                float res = wrapper::vgetlane(carry_res, 0);
                res *= wrapper::vgetlane(carry_res, 1);
                res *= wrapper::vgetlane(carry_res, 2);
                res *= wrapper::vgetlane(carry_res, 3);

                //re-quantize result
                // NOTE(review): uses the input's scale/offset for the output quantization — confirm the output shares them.
                res             = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset);
                *(output.ptr()) = static_cast<uint8_t>(res);
                break;
            }
            default:
            {
                // SUM / MEAN_SUM path: add the four uint32 accumulators, then fold 4 lanes to 1 with pairwise adds.
                auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
                carry_res      = wrapper::vadd(carry_res, vec_res_value3);
                carry_res      = wrapper::vadd(carry_res, vec_res_value4);

                auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
                carry_paddition      = wrapper::vpadd(carry_paddition, carry_paddition);
                auto res             = wrapper::vgetlane(carry_paddition, 0);

                if(op == ReductionOperation::MEAN_SUM)
                {
                    // Integer division: the mean is truncated toward zero.
                    res /= in_info.dimension(0);
                }

                // NOTE(review): static_cast<uint8_t> keeps only the low 8 bits of the 32-bit sum (no saturation,
                // no requantization) — confirm that this is the intended SUM output semantics.
                *(output.ptr()) = static_cast<uint8_t>(res);
            }
        }
    }
};
700
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000701template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100702struct RedOpYZW
703{
704 /** NEON vector tag type. */
705 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000706 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100707
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000708 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100709 {
710 ARM_COMPUTE_UNUSED(out_slice);
711
giuros01154bc1c2019-03-26 17:44:40 +0000712 execute_window_loop(in_slice, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100713 {
Michalis Spyrou728d6f72019-01-16 13:57:58 +0000714 neon_vector vec_res_value = { 0 };
Usama Arifa4a08ad2019-05-20 12:38:33 +0100715 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000716 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100717 case ReductionOperation::ARG_IDX_MAX:
718 case ReductionOperation::ARG_IDX_MIN:
719 case ReductionOperation::MIN:
Usama Arif28f0dd92019-05-20 13:44:34 +0100720 case ReductionOperation::MAX:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100721 {
722 vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
723 break;
724 }
725 case ReductionOperation::PROD:
726 {
727 vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
728 break;
729 }
730 default:
731 {
732 vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
733 break;
734 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000735 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000736 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000737
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100738 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
739 {
740 T *in_ptr;
741 switch(axis)
742 {
743 case 1:
744 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim)));
745 break;
746 case 2:
747 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim)));
748 break;
749 case 3:
750 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim)));
751 break;
752 default:
753 ARM_COMPUTE_ERROR("Not supported");
754 }
755 const auto vec_elements = wrapper::vloadq(in_ptr);
756
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000757 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100758 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000759 case ReductionOperation::SUM:
760 case ReductionOperation::MEAN_SUM:
761 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
762 break;
763 case ReductionOperation::SUM_SQUARE:
764 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
765 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000766 case ReductionOperation::PROD:
767 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
768 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000769 case ReductionOperation::ARG_IDX_MIN:
770 {
771 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
772 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
773 vec_res_value = temp_vec_res_value;
774 break;
775 }
776 case ReductionOperation::ARG_IDX_MAX:
777 {
778 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
779 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
780 vec_res_value = temp_vec_res_value;
781 break;
782 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100783 case ReductionOperation::MIN:
784 {
785 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
786 break;
787 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100788 case ReductionOperation::MAX:
789 {
790 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
791 break;
792 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000793 default:
794 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100795 }
796 }
797
798 if(op == ReductionOperation::MEAN_SUM)
799 {
800 auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000801 vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100802 }
803
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000804 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
805 {
806 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
807 }
808 else
809 {
810 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value);
811 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100812 },
813 input, output);
814 }
815};
816
giuros01154bc1c2019-03-26 17:44:40 +0000817template <typename T, int S, int axis, ReductionOperation op>
818struct RedOpYZW_complex
819{
820 /** NEON vector tag type. */
821 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
822 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
823
824 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int, const ReductionOperation)
825 {
826 ARM_COMPUTE_UNUSED(out_slice);
827 ARM_COMPUTE_ERROR_ON(axis != 2);
828
829 const size_t stride_z = in_info.strides_in_bytes()[axis];
830
831 execute_window_loop(in_slice, [&](const Coordinates &)
832 {
833 neon_vector vec_res_value_0 = { 0 };
834 neon_vector vec_res_value_1 = { 0 };
835
836 vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
837 vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
838
839 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
840 {
841 T *in_ptr_0;
842 T *in_ptr_1;
843 switch(axis)
844 {
845 case 2:
846 in_ptr_0 = reinterpret_cast<T *>(input.ptr() + stride_z * dim);
847 in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 16 + stride_z * dim);
848 break;
849 default:
850 ARM_COMPUTE_ERROR("Not supported");
851 }
852 const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
853 const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);
854
855 switch(op)
856 {
857 case ReductionOperation::SUM:
858 vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
859 vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
860 break;
861 default:
862 ARM_COMPUTE_ERROR("Not supported");
863 }
864 }
865
866 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value_0);
867 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + 16), vec_res_value_1);
868
869 },
870 input, output);
871 }
872};
873
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100874struct RedOpYZW_qasymm8
875{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000876 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100877 {
878 ARM_COMPUTE_UNUSED(out_slice);
879
giuros01154bc1c2019-03-26 17:44:40 +0000880 execute_window_loop(in_slice, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100881 {
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000882 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000883 auto vec_res_value1 = vdupq_n_u32(0);
884 auto vec_res_value2 = vdupq_n_u32(0);
885 auto vec_res_value3 = vdupq_n_u32(0);
886 auto vec_res_value4 = vdupq_n_u32(0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000887
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000888 auto vec_res_value1_f = vdupq_n_f32(1);
889 auto vec_res_value2_f = vdupq_n_f32(1);
890 auto vec_res_value3_f = vdupq_n_f32(1);
891 auto vec_res_value4_f = vdupq_n_f32(1);
892
893 auto vec_res_value = wrapper::vloadq(input.ptr());
894
895 for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100896 {
897 uint8_t *in_ptr;
898 switch(axis)
899 {
900 case 1:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000901 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100902 break;
903 case 2:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000904 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100905 break;
906 case 3:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000907 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100908 break;
909 default:
910 ARM_COMPUTE_ERROR("Not supported");
911 }
912 const auto vec_elements = wrapper::vloadq(in_ptr);
913
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000914 switch(op)
915 {
916 case ReductionOperation::SUM:
917 case ReductionOperation::MEAN_SUM:
918 {
919 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
920 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100921
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000922 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
923 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
924 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
925 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100926
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000927 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
928 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
929 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
930 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
931 break;
932 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000933 case ReductionOperation::PROD:
934 {
935 const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
936 const auto scale32x4f_4 = vdupq_n_f32(in_info.quantization_info().scale);
937
938 const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
939 const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
940
941 const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
942 const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
943 const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
944 const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
945
946 auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
947 auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
948 auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
949 auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
950
951 //de-quantize vec_elements
952 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
953 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
954 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
955 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
956
957 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
958 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
959 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
960 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
961 break;
962 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000963 case ReductionOperation::ARG_IDX_MIN:
964 {
965 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000966 vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000967 vec_res_value = temp_vec_res_value;
968 break;
969 }
970 case ReductionOperation::ARG_IDX_MAX:
971 {
972 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000973 vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000974 vec_res_value = temp_vec_res_value;
975 break;
976 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100977 case ReductionOperation::MIN:
978 {
979 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
980 break;
981 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100982 case ReductionOperation::MAX:
983 {
984 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
985 break;
986 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000987 default:
988 ARM_COMPUTE_ERROR("Not supported");
989 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100990 }
991
992 if(op == ReductionOperation::MEAN_SUM)
993 {
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000994 const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
995 vec_res_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv);
996 vec_res_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv);
997 vec_res_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv);
998 vec_res_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100999
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001000 vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
1001 vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
1002 vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
1003 vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
1004 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001005 else if(op == ReductionOperation::PROD)
1006 {
1007 const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
1008 const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(in_info.quantization_info().scale));
1009
1010 //re-quantize
1011 vec_res_value1_f = vaddq_f32(vmulq_f32(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
1012 vec_res_value2_f = vaddq_f32(vmulq_f32(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
1013 vec_res_value3_f = vaddq_f32(vmulq_f32(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
1014 vec_res_value4_f = vaddq_f32(vmulq_f32(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
1015
1016 vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
1017 vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
1018 vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
1019 vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
1020 }
1021
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001022 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
1023 {
1024 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
1025 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vec_res_idx.val[1]);
1026 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vec_res_idx.val[2]);
1027 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vec_res_idx.val[3]);
1028 }
Usama Arifa4a08ad2019-05-20 12:38:33 +01001029 else if(op == ReductionOperation::ARG_IDX_MIN)
1030 {
1031 wrapper::vstore(output.ptr(), vec_res_value);
1032 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001033 else
1034 {
1035 const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1036 const auto temp16x8t_2 = vcombine_u16(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1037 auto res = vcombine_u8(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
1038 wrapper::vstore(output.ptr(), res);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001039 }
1040
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001041 },
1042 input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001043 }
1044};
1045
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001046void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001047{
giuros01154bc1c2019-03-26 17:44:40 +00001048 const bool is_complex = (input->info()->num_channels() == 2);
1049
1050 if(is_complex)
1051 {
1052 switch(axis)
1053 {
1054 case 2:
1055 switch(input->info()->data_type())
1056 {
1057 case DataType::F32:
1058 switch(op)
1059 {
1060 case ReductionOperation::SUM:
1061 return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op);
1062 default:
1063 ARM_COMPUTE_ERROR("Not supported");
1064 }
1065 default:
1066 ARM_COMPUTE_ERROR("Not supported");
1067 }
1068 default:
1069 ARM_COMPUTE_ERROR("Not supported");
1070 }
1071 }
1072
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001073 switch(axis)
1074 {
1075 case 0:
1076 switch(input->info()->data_type())
1077 {
1078 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001079 return Reducer<RedOpX_qasymm8>::reduceX(window, input, output, RedOpX_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001080#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1081 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001082 return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001083#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1084 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001085 return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001086 case DataType::S32:
1087 return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001088 default:
1089 ARM_COMPUTE_ERROR("Not supported");
1090 }
1091 case 1:
1092 switch(input->info()->data_type())
1093 {
1094 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001095 return Reducer<RedOpYZW_qasymm8>::reduceY(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001096#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1097 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001098 return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001099#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1100 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001101 return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001102 case DataType::S32:
1103 return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001104 default:
1105 ARM_COMPUTE_ERROR("Not supported");
1106 }
1107 case 2:
1108 switch(input->info()->data_type())
1109 {
1110 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001111 return Reducer<RedOpYZW_qasymm8>::reduceZ(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001112#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1113 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001114 return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001115#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1116 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001117 return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001118 case DataType::S32:
1119 return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001120 default:
1121 ARM_COMPUTE_ERROR("Not supported");
1122 }
1123 case 3:
1124 switch(input->info()->data_type())
1125 {
1126 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001127 return Reducer<RedOpYZW_qasymm8>::reduceW(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001128#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1129 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001130 return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001131#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1132 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001133 return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001134 case DataType::S32:
1135 return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001136 default:
1137 ARM_COMPUTE_ERROR("Not supported");
1138 }
1139 default:
1140 ARM_COMPUTE_ERROR("Unsupported reduction axis");
1141 }
1142}
John Richardson73d4aef2018-05-08 14:34:33 +01001143
1144Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1145{
1146 ARM_COMPUTE_UNUSED(op);
1147
1148 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
Georgios Pinitas8f5802f2019-02-22 11:08:32 +00001149 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
giuros01154bc1c2019-03-26 17:44:40 +00001150
1151 if(input->num_channels() == 1)
1152 {
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001153 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
giuros01154bc1c2019-03-26 17:44:40 +00001154 }
1155 else
1156 {
1157 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
1158 ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM);
1159 ARM_COMPUTE_RETURN_ERROR_ON(axis != 2);
1160 }
John Richardson73d4aef2018-05-08 14:34:33 +01001161
1162 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001163 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
John Richardson73d4aef2018-05-08 14:34:33 +01001164
1165 if(output->total_size() != 0)
1166 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001167 bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
1168 if(!is_arg_min_max)
1169 {
1170 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Isabella Gottardi0a1090a2019-02-14 18:07:36 +00001171 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
giuros01154bc1c2019-03-26 17:44:40 +00001172 ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels());
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001173 }
1174 else
1175 {
1176 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32);
1177 }
John Richardson73d4aef2018-05-08 14:34:33 +01001178
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001179 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001180 const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
1181 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
1182 }
1183
1184 return Status{};
1185}
1186
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001187std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op)
John Richardson73d4aef2018-05-08 14:34:33 +01001188{
1189 // Calculate output shape and set if empty
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001190 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001191
1192 // Output auto initialization if not yet initialized
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001193 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
1194 DataType output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
giuros01154bc1c2019-03-26 17:44:40 +00001195 auto_init_if_empty(*output, input->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
John Richardson73d4aef2018-05-08 14:34:33 +01001196
1197 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
1198
1199 // Configure kernel window
1200 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
1201 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
1202 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
1203
1204 bool window_changed = update_window_and_padding(win, input_access, output_access);
1205 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
1206
1207 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
1208
1209 return std::make_tuple(err, win);
1210}
Georgios Pinitasd9769582017-08-03 10:19:40 +01001211} // namespace
1212
1213NEReductionOperationKernel::NEReductionOperationKernel()
1214 : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE), _border_size()
1215{
1216}
1217
1218BorderSize NEReductionOperationKernel::border_size() const
1219{
1220 return _border_size;
1221}
1222
1223void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
1224{
1225 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001226
John Richardson73d4aef2018-05-08 14:34:33 +01001227 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001228
1229 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
1230
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001231 _input = input;
1232 _output = output;
1233 _border_size = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
1234 _op = op;
1235 _reduction_axis = axis;
Georgios Pinitasd9769582017-08-03 10:19:40 +01001236
1237 // Configure kernel window
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001238 auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis, op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001239
John Richardson73d4aef2018-05-08 14:34:33 +01001240 ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001241
John Richardson73d4aef2018-05-08 14:34:33 +01001242 INEKernel::configure(std::get<1>(win_config));
1243}
1244
1245Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1246{
1247 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001248 ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis, op)));
John Richardson73d4aef2018-05-08 14:34:33 +01001249
1250 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +01001251}
1252
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001253void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001254{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001255 ARM_COMPUTE_UNUSED(info);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001256 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1257 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1258
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001259 reduce_op(window, _input, _output, _reduction_axis, _op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001260}
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001261} // namespace arm_compute