/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEReductionOperationKernel.h"

#include "arm_compute/core/Coordinates.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/INEKernel.h"
#include "arm_compute/core/NEON/NEMath.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"

#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include <algorithm> // std::min is used when collapsing index vectors
#include <arm_neon.h>

namespace arm_compute
{
namespace
{
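// Updates the per-lane argmin/argmax indices for a float32x4 comparison: lanes where the
// candidate `a` beats the running result `b` take the element index derived from `idx`,
// the remaining lanes keep the indices already stored in `c`. Only val[0] is used for F32.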
uint32x4x4_t calculate_index(uint32_t idx, float32x4_t a, float32x4_t b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4_t mask{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask = wrapper::vcgt(b, a);
    }
    else
    {
        mask = wrapper::vclt(b, a);
    }

    uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
    if(axis != 0)
    {
        vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    uint32x4x4_t res = { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 };

    return res;
}

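// Same index-update helper for sixteen u8 lanes: the u8 comparison mask is widened to four
// u32x4 masks so that a 32-bit index can be tracked for every byte lane.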
uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4x4_t mask{ 0 };
    uint8x16_t   mask_u8{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask_u8 = wrapper::vcgt(b, a);
    }
    else
    {
        mask_u8 = wrapper::vclt(b, a);
    }
    mask.val[0] = wrapper::vmovl(wrapper::vgetlow(wrapper::vmovl(wrapper::vgetlow(mask_u8))));
    mask.val[1] = wrapper::vmovl(wrapper::vgethigh(wrapper::vmovl(wrapper::vgetlow(mask_u8))));
    mask.val[2] = wrapper::vmovl(wrapper::vgetlow(wrapper::vmovl(wrapper::vgethigh(mask_u8))));
    mask.val[3] = wrapper::vmovl(wrapper::vgethigh(wrapper::vmovl(wrapper::vgethigh(mask_u8))));
    uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
                               { idx + 4, idx + 5, idx + 6, idx + 7 },
                               { idx + 8, idx + 9, idx + 10, idx + 11 },
                               { idx + 12, idx + 13, idx + 14, idx + 15 }
                             }
                           };
    if(axis != 0)
    {
        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    uint32x4x4_t res = { vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
                         vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
                         vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
                         vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
                       };

    return res;
}

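// Collapses the per-lane index vector to a single scalar index: the winning value is found
// with pairwise min/max, the lanes holding it are selected, and the smallest matching index
// is returned (the +/- 0xFFFFFFFF offset keeps non-matching lanes out of the final minimum).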
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float32x4_t vec_res_value, ReductionOperation op)
{
    uint32x4_t res_idx_mask{ 0 };
    uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);

    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmin = wrapper::vpmin(pmin, pmin);
        auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
        res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
    }
    else
    {
        auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmax = wrapper::vpmax(pmax, pmax);
        auto mask = vceqq_f32(vec_res_value, wrapper::vcombine(pmax, pmax));
        res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
    }

    res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
    auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
    pmin = wrapper::vpmin(pmin, pmin);
    uint32_t res = wrapper::vgetlane(pmin, 0);

    return (res - 0xFFFFFFFF);
}

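// u8 variant of the reduction above: the equality mask is widened from u8 to u32 so it can
// be applied to the four index vectors, then the smallest matching index is extracted.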
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_value, ReductionOperation op)
{
    uint32x4x4_t res_idx_mask{ 0 };
    uint32x4_t   mask_ones = vdupq_n_u32(0xFFFFFFFF);
    uint8x16_t   mask_u8{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmin = wrapper::vpmin(pmin, pmin);
        pmin = wrapper::vpmin(pmin, pmin);
        pmin = wrapper::vpmin(pmin, pmin);
        mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
    }
    else
    {
        auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmax = wrapper::vpmax(pmax, pmax);
        pmax = wrapper::vpmax(pmax, pmax);
        pmax = wrapper::vpmax(pmax, pmax);
        mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
    }

    // Widen vectors
    auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
    auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
    auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
    auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
    auto wide_u32_3 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
    auto wide_u32_4 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
    res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
    res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
    res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
    res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);

    uint32_t res  = 0xFFFFFFFF;
    int      iter = 0;
    do
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
        pmin = wrapper::vpmin(pmin, pmin);
        res  = std::min(wrapper::vgetlane(pmin, 0), res);
        iter++;
    }
    while(iter < 4);

    return (res - 0xFFFFFFFF);
}
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
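// Half-precision counterparts of the two helpers above, compiled only when FP16 vector
// arithmetic is available.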
uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4x2_t mask{ 0 };
    uint16x8_t   mask_u16{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask_u16 = wrapper::vcgt(b, a);
    }
    else
    {
        mask_u16 = wrapper::vclt(b, a);
    }
    mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
    mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
    uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
                               { idx + 4, idx + 5, idx + 6, idx + 7 }
                             }
                           };
    if(axis != 0)
    {
        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
                         wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
                         0, 0
                       };

    return res;
}

uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
{
    uint32x4x2_t res_idx_mask{ 0 };
    uint32x4_t   mask_ones = vdupq_n_u32(0xFFFFFFFF);
    uint16x8_t   mask_u16;
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmin = wrapper::vpmin(pmin, pmin);
        pmin = wrapper::vpmin(pmin, pmin);
        mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
    }
    else
    {
        auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmax = wrapper::vpmax(pmax, pmax);
        pmax = wrapper::vpmax(pmax, pmax);
        mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
    }

    // Widen vectors
    auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
    auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);

    uint32_t res  = 0xFFFFFFFF;
    int      iter = 0;
    do
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
        pmin = wrapper::vpmin(pmin, pmin);
        res  = std::min(wrapper::vgetlane(pmin, 0), res);
        iter++;
    }
    while(iter < 2);

    return (res - 0xFFFFFFFF);
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

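// Dispatches a reduction functor over the execution window: reduceX collapses each row into a
// single output, while reduceY/Z/W iterate 2D/3D/4D slices and pass the reduced axis to the functor.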
template <class F>
class Reducer
{
public:
    static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set out window
        Window out_window(window);
        out_window.set(Window::DimX, Window::Dimension(0, 0, 0));

        // Get first input and output slices
        Window in_slice  = window.first_slice_window_1D();
        Window out_slice = out_window.first_slice_window_1D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), op);
        }
        while(window.slide_window_slice_1D(in_slice) && out_window.slide_window_slice_1D(out_slice));
    }
    static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in window
        Window in_window(window);
        Window out_window(window);

        in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
        out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));

        // Get first input and output slices
        Window in_slice  = in_window.first_slice_window_2D();
        Window out_slice = out_window.first_slice_window_2D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), 1, op);
        }
        while(in_window.slide_window_slice_2D(in_slice) && out_window.slide_window_slice_2D(out_slice));
    }
    static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in window
        Window in_window(window);
        Window out_window(window);

        in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
        out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));

        // Get first input and output slices
        Window in_slice  = in_window.first_slice_window_3D();
        Window out_slice = out_window.first_slice_window_3D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), 2, op);
        }
        while(in_window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_3D(out_slice));
    }
    static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in/out window
        Window in_window(window);
        Window out_window(window);

        in_window.set(3, Window::Dimension(0, 1, 1));
        out_window.set(3, Window::Dimension(0, 1, 1));

        // Get first input and output slices
        Window in_slice  = in_window.first_slice_window_4D();
        Window out_slice = out_window.first_slice_window_4D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), 3, op);
        }
        while(in_window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_4D(out_slice));
    }
};

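// Reduction along the X axis for F16/F32: accumulates whole vectors inside the window loop and
// performs the final horizontal reduction (or argmin/argmax index extraction) for each row.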
template <typename T, int S>
struct RedOpX
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;

    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
    {
        ARM_COMPUTE_UNUSED(out_slice);
        auto init_res_value = static_cast<T>(0.f);
        if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
        {
            init_res_value = *reinterpret_cast<T *>(input.ptr());
        }
        auto         vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
        uint32x4x4_t vec_res_idx{ 0 };

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            const auto in_ptr       = reinterpret_cast<const T *>(input.ptr());
            const auto vec_elements = wrapper::vloadq(in_ptr);

            switch(op)
            {
                case ReductionOperation::SUM_SQUARE:
                    vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                    break;
                case ReductionOperation::MEAN_SUM:
                case ReductionOperation::SUM:
                    vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                    break;
                case ReductionOperation::ARG_IDX_MIN:
                {
                    auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                    vec_res_idx   = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value = temp_vec_res_value;
                    break;
                }
                case ReductionOperation::ARG_IDX_MAX:
                {
                    auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                    vec_res_idx   = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value = temp_vec_res_value;
                    break;
                }
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        },
        input);

        switch(op)
        {
            case ReductionOperation::SUM:
            case ReductionOperation::SUM_SQUARE:
            case ReductionOperation::MEAN_SUM:
            {
                auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                for(int i = 0; i < S / 4; ++i)
                {
                    carry_res = wrapper::vpadd(carry_res, carry_res);
                }
                auto res = wrapper::vgetlane(carry_res, 0);

                if(op == ReductionOperation::MEAN_SUM)
                {
                    res /= in_info.dimension(0);
                }

                *(reinterpret_cast<T *>(output.ptr())) = res;
                break;
            }
            case ReductionOperation::ARG_IDX_MIN:
            case ReductionOperation::ARG_IDX_MAX:
            {
                auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
                *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
                break;
            }
            default:
                ARM_COMPUTE_ERROR("Not supported");
        }
    }
};

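// QASYMM8 specialisation of the X-axis reduction: u8 elements are widened into four u32
// accumulators to avoid overflow, then collapsed to a single value (or index) per row.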
struct RedOpX_qasymm8
{
    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
    {
        ARM_COMPUTE_UNUSED(out_slice);
        auto vec_res_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_res_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_res_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_res_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));

        uint8x16_t vec_res_value;
        if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
        {
            vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{});
        }

        uint32x4x4_t vec_res_idx{ 0 };
        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            const auto vec_elements = wrapper::vloadq(input.ptr());
            switch(op)
            {
                case ReductionOperation::SUM:
                case ReductionOperation::MEAN_SUM:
                {
                    const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
                    const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));

                    const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
                    const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
                    const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
                    const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));

                    vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
                    vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
                    vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
                    vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
                    break;
                }
                case ReductionOperation::ARG_IDX_MIN:
                {
                    auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                    vec_res_idx   = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value = temp_vec_res_value;
                    break;
                }
                case ReductionOperation::ARG_IDX_MAX:
                {
                    auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                    vec_res_idx   = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value = temp_vec_res_value;
                    break;
                }
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        },
        input);

        if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
        {
            auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
            *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
        }
        else
        {
            auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
            carry_res      = wrapper::vadd(carry_res, vec_res_value3);
            carry_res      = wrapper::vadd(carry_res, vec_res_value4);

            auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
            carry_paddition      = wrapper::vpadd(carry_paddition, carry_paddition);
            auto res             = wrapper::vgetlane(carry_paddition, 0);

            if(op == ReductionOperation::MEAN_SUM)
            {
                res /= in_info.dimension(0);
            }

            *(output.ptr()) = static_cast<uint8_t>(res);
        }
    }
};

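// Reduction along the Y, Z or W axis for F16/F32: for each point of the output window the
// functor walks the reduced dimension via offset_element_in_bytes, accumulating a full vector
// at a time, and stores either the reduced values or the argmin/argmax indices.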
template <typename T, int S>
struct RedOpYZW
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
    using neon_vector  = typename wrapper::traits::neon_vector<T, S>::type;

    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
    {
        ARM_COMPUTE_UNUSED(out_slice);

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            neon_vector vec_res_value;
            if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
            {
                vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
            }
            else
            {
                vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
            }
            uint32x4x4_t vec_res_idx{ 0 };

            for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
            {
                T *in_ptr;
                switch(axis)
                {
                    case 1:
                        in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim)));
                        break;
                    case 2:
                        in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim)));
                        break;
                    case 3:
                        in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim)));
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
                const auto vec_elements = wrapper::vloadq(in_ptr);

                switch(op)
                {
                    case ReductionOperation::SUM:
                    case ReductionOperation::MEAN_SUM:
                        vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                        break;
                    case ReductionOperation::SUM_SQUARE:
                        vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                        break;
                    case ReductionOperation::ARG_IDX_MIN:
                    {
                        auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                        vec_res_idx   = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                        vec_res_value = temp_vec_res_value;
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MAX:
                    {
                        auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                        vec_res_idx   = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                        vec_res_value = temp_vec_res_value;
                        break;
                    }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            }

            if(op == ReductionOperation::MEAN_SUM)
            {
                auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
                vec_res_value      = wrapper::vmul(vec_res_value, vec_width_inv);
            }

            if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
            {
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
            }
            else
            {
                wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value);
            }
        },
        input, output);
    }
};

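// QASYMM8 specialisation of the Y/Z/W reduction, again using four u32 accumulators and
// converting through F32 for MEAN_SUM before narrowing the result back to u8.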
struct RedOpYZW_qasymm8
{
    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
    {
        ARM_COMPUTE_UNUSED(out_slice);

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            uint32x4x4_t vec_res_idx{ 0 };
            auto vec_res_value1 = vdupq_n_u32(0);
            auto vec_res_value2 = vdupq_n_u32(0);
            auto vec_res_value3 = vdupq_n_u32(0);
            auto vec_res_value4 = vdupq_n_u32(0);
            auto vec_res_value  = wrapper::vloadq(input.ptr());

            for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
            {
                uint8_t *in_ptr;
                switch(axis)
                {
                    case 1:
                        in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim));
                        break;
                    case 2:
                        in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim));
                        break;
                    case 3:
                        in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim));
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
                const auto vec_elements = wrapper::vloadq(in_ptr);

                switch(op)
                {
                    case ReductionOperation::SUM:
                    case ReductionOperation::MEAN_SUM:
                    {
                        const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
                        const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));

                        const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
                        const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
                        const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
                        const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));

                        vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
                        vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
                        vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
                        vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MIN:
                    {
                        auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                        vec_res_idx   = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                        vec_res_value = temp_vec_res_value;
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MAX:
                    {
                        auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                        vec_res_idx   = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                        vec_res_value = temp_vec_res_value;
                        break;
                    }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            }

            if(op == ReductionOperation::MEAN_SUM)
            {
                const auto vec_width_inv    = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
                const auto vec_res_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv);
                const auto vec_res_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv);
                const auto vec_res_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv);
                const auto vec_res_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv);

                vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
                vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
                vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
                vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
            }
            if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
            {
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vec_res_idx.val[1]);
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vec_res_idx.val[2]);
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vec_res_idx.val[3]);
            }
            else
            {
                const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
                const auto temp16x8t_2 = vcombine_u16(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
                auto       res         = vcombine_u8(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
                wrapper::vstore(output.ptr(), res);
            }
        },
        input, output);
    }
};

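// Selects the Reducer/functor combination matching the reduction axis and the input data type.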
void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
{
    switch(axis)
    {
        case 0:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpX_qasymm8>::reduceX(window, input, output, RedOpX_qasymm8(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 1:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8>::reduceY(window, input, output, RedOpYZW_qasymm8(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 2:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8>::reduceZ(window, input, output, RedOpYZW_qasymm8(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 3:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8>::reduceW(window, input, output, RedOpYZW_qasymm8(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        default:
            ARM_COMPUTE_ERROR("Unsupported reduction axis");
    }
}

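// Validates the input/output tensor infos: supported data types, axis range, and (when the output
// is already initialised) matching layout and shape, plus the U32 requirement for argmin/argmax.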
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
    ARM_COMPUTE_UNUSED(op);

    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");

    if(output->total_size() != 0)
    {
        bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
        if(!is_arg_min_max)
        {
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
        }
        else
        {
            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32);
        }

        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);

        const TensorShape output_shape         = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
        const TensorInfo  tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
    }

    return Status{};
}

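// Auto-initialises the output tensor info if it is empty (U32 for argmin/argmax, otherwise the
// input type) and computes the kernel window, requesting one full vector of horizontal access
// on both tensors.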
std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
    // Calculate output shape and set if empty
    const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);

    // Output auto initialization if not yet initialized
    const bool is_arg_min_max   = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
    DataType   output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
    auto_init_if_empty(*output, output_shape, 1, output_data_type);

    unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());

    // Configure kernel window
    Window                 win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
    AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);

    bool window_changed = update_window_and_padding(win, input_access, output_access);
    output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};

    return std::make_tuple(err, win);
}
} // namespace

NEReductionOperationKernel::NEReductionOperationKernel()
    : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE), _border_size()
{
}

BorderSize NEReductionOperationKernel::border_size() const
{
    return _border_size;
}

void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));

    unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());

    _input          = input;
    _output         = output;
    _border_size    = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
    _op             = op;
    _reduction_axis = axis;

    // Configure kernel window
    auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis, op);

    ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));

    INEKernel::configure(std::get<1>(win_config));
}

Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
    ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis, op)));

    return Status{};
}

void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    reduce_op(window, _input, _output, _reduction_axis, _op);
}
} // namespace arm_compute