blob: b51d4b311fc955284ba4d7fe01421e2cca262fa9 [file] [log] [blame]
Georgios Pinitasd9769582017-08-03 10:19:40 +01001/*
Michalis Spyrouaea14c62019-01-03 11:10:25 +00002 * Copyright (c) 2017-2019 ARM Limited.
Georgios Pinitasd9769582017-08-03 10:19:40 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEReductionOperationKernel.h"
25
Georgios Pinitas8f5802f2019-02-22 11:08:32 +000026#include "arm_compute/core/CPP/Validate.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010027#include "arm_compute/core/Coordinates.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/IAccessWindow.h"
30#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/INEKernel.h"
32#include "arm_compute/core/NEON/NEMath.h"
John Richardson73d4aef2018-05-08 14:34:33 +010033#include "arm_compute/core/TensorInfo.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010034#include "arm_compute/core/Validate.h"
Michalis Spyrouaea14c62019-01-03 11:10:25 +000035#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010036
Michalis Spyroubcf8a962018-10-12 10:51:31 +010037#include "arm_compute/core/NEON/wrapper/wrapper.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010038#include <arm_neon.h>
39
Michalis Spyroubcf8a962018-10-12 10:51:31 +010040namespace arm_compute
41{
Georgios Pinitasd9769582017-08-03 10:19:40 +010042namespace
43{
Michalis Spyroub9626ab2019-05-13 17:41:01 +010044template <typename T>
45uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000046{
47 uint32x4_t mask{ 0 };
48 if(op == ReductionOperation::ARG_IDX_MIN)
49 {
50 mask = wrapper::vcgt(b, a);
51 }
52 else
53 {
54 mask = wrapper::vclt(b, a);
55 }
56
57 uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
58 if(axis != 0)
59 {
60 vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
61 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +000062 uint32x4x4_t res = { { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000063
64 return res;
65}
66
67uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x4_t c, ReductionOperation op, int axis)
68{
Georgios Pinitasd57891a2019-02-19 18:10:03 +000069 uint32x4x4_t mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000070 uint8x16_t mask_u8{ 0 };
71 if(op == ReductionOperation::ARG_IDX_MIN)
72 {
73 mask_u8 = wrapper::vcgt(b, a);
74 }
75 else
76 {
77 mask_u8 = wrapper::vclt(b, a);
78 }
Michalis Spyrou254a48a2019-01-14 17:27:39 +000079 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
80 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
81 mask.val[0] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
82 mask.val[1] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
83 mask.val[2] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
84 mask.val[3] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
85
Michalis Spyrouaea14c62019-01-03 11:10:25 +000086 uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
87 { idx + 4, idx + 5, idx + 6, idx + 7 },
88 { idx + 8, idx + 9, idx + 10, idx + 11 },
89 { idx + 12, idx + 13, idx + 14, idx + 15 }
90 }
91 };
92 if(axis != 0)
93 {
94 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
95 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
96 vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
97 vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
98 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +000099 uint32x4x4_t res =
100 {
101 {
102 vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
103 vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
104 vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
105 vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
106 }
107 };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000108
109 return res;
110}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100111
112// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
113float32x2_t calculate_min(float32x4_t in)
114{
115 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
116 return wrapper::vpmin(pmin, pmin);
117}
118
119// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
120float32x2_t calculate_max(float32x4_t in)
121{
122 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
123 return wrapper::vpmax(pmax, pmax);
124}
125// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
126int32x2_t calculate_min(int32x4_t in)
127{
128 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
129 return wrapper::vpmin(pmin, pmin);
130}
131
132// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
133int32x2_t calculate_max(int32x4_t in)
134{
135 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
136 return wrapper::vpmax(pmax, pmax);
137}
138
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100139template <typename T>
140uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000141{
142 uint32x4_t res_idx_mask{ 0 };
143 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
144
145 if(op == ReductionOperation::ARG_IDX_MIN)
146 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100147 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000148 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
149 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
150 }
151 else
152 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100153 auto pmax = calculate_max(vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100154 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000155 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
156 }
157
158 res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
159 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
160 pmin = wrapper::vpmin(pmin, pmin);
161 uint32_t res = wrapper::vgetlane(pmin, 0);
162
163 return (res - 0xFFFFFFFF);
164}
165
Usama Arifa4a08ad2019-05-20 12:38:33 +0100166// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
167inline uint8x8_t calculate_min(uint8x16_t in)
168{
169 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
170 pmin = wrapper::vpmin(pmin, pmin);
171 pmin = wrapper::vpmin(pmin, pmin);
172 return wrapper::vpmin(pmin, pmin);
173}
174// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
175inline uint8x8_t calculate_max(uint8x16_t in)
176{
177 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
178 pmax = wrapper::vpmax(pmax, pmax);
179 pmax = wrapper::vpmax(pmax, pmax);
180 return wrapper::vpmax(pmax, pmax);
181}
182
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000183uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_value, ReductionOperation op)
184{
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000185 uint32x4x4_t res_idx_mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000186 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
187 uint8x16_t mask_u8{ 0 };
188 if(op == ReductionOperation::ARG_IDX_MIN)
189 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100190 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000191 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
192 }
193 else
194 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100195 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000196 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
197 }
198
199 // Widen vectors
200 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
201 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
202 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
203 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
204 auto wide_u32_3 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
205 auto wide_u32_4 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
206 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
207 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
208 res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
209 res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
210 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
211 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
212 res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
213 res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);
214
215 uint32_t res = 0xFFFFFFFF;
216 int iter = 0;
217 do
218 {
219 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
220 pmin = wrapper::vpmin(pmin, pmin);
221 res = std::min(wrapper::vgetlane(pmin, 0), res);
222 iter++;
223 }
224 while(iter < 4);
225
226 return (res - 0xFFFFFFFF);
227}
228#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
229uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
230{
231 uint32x4x2_t mask{ 0 };
232 uint16x8_t mask_u16{ 0 };
233 if(op == ReductionOperation::ARG_IDX_MIN)
234 {
235 mask_u16 = wrapper::vcgt(b, a);
236 }
237 else
238 {
239 mask_u16 = wrapper::vclt(b, a);
240 }
241 mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
242 mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
243 uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
244 { idx + 4, idx + 5, idx + 6, idx + 7 }
245 }
246 };
247 if(axis != 0)
248 {
249 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
250 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
251 }
252 uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
253 wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
254 0, 0
255 };
256
257 return res;
258}
259
Usama Arifa4a08ad2019-05-20 12:38:33 +0100260// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
261inline float16x4_t calculate_min(float16x8_t in)
262{
263 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
264 pmin = wrapper::vpmin(pmin, pmin);
265 return wrapper::vpmin(pmin, pmin);
266}
267// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
268inline float16x4_t calculate_max(float16x8_t in)
269{
270 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
271 pmax = wrapper::vpmax(pmax, pmax);
272 return wrapper::vpmax(pmax, pmax);
273}
274
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000275uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
276{
277 uint32x4x2_t res_idx_mask{ 0 };
278 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
279 uint16x8_t mask_u16;
280 if(op == ReductionOperation::ARG_IDX_MIN)
281 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100282 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000283 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
284 }
285 else
286 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100287 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000288 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
289 }
290
291 // Widen vectors
292 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
293 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
294 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
295 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
296 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
297 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
298
299 uint32_t res = 0xFFFFFFFF;
300 int iter = 0;
301 do
302 {
303 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
304 pmin = wrapper::vpmin(pmin, pmin);
305 res = std::min(wrapper::vgetlane(pmin, 0), res);
306 iter++;
307 }
308 while(iter < 2);
309
310 return (res - 0xFFFFFFFF);
311}
312#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
313
Georgios Pinitasd9769582017-08-03 10:19:40 +0100314template <class F>
315class Reducer
316{
317public:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000318 static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100319 {
320 // Set out window
321 Window out_window(window);
322 out_window.set(Window::DimX, Window::Dimension(0, 0, 0));
323
324 // Get first input and output slices
325 Window in_slice = window.first_slice_window_1D();
326 Window out_slice = out_window.first_slice_window_1D();
327
328 do
329 {
330 Iterator in(input, in_slice);
331 Iterator out(output, out_slice);
332
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000333 f(in, out, in_slice, out_slice, *input->info(), op);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100334 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100335 while(window.slide_window_slice_1D(in_slice) && out_window.slide_window_slice_1D(out_slice));
336 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000337 static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100338 {
339 // Set in window
340 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000341 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100342
343 in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000344 out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100345
346 // Get first input and output slices
347 Window in_slice = in_window.first_slice_window_2D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000348 Window out_slice = out_window.first_slice_window_2D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100349
350 do
351 {
352 Iterator in(input, in_slice);
353 Iterator out(output, out_slice);
354
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000355 f(in, out, in_slice, out_slice, *input->info(), 1, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100356 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000357 while(in_window.slide_window_slice_2D(in_slice) && out_window.slide_window_slice_2D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100358 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000359 static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100360 {
361 // Set in window
362 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000363 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100364
365 in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000366 out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100367
368 // Get first input and output slices
369 Window in_slice = in_window.first_slice_window_3D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000370 Window out_slice = out_window.first_slice_window_3D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100371
372 do
373 {
374 Iterator in(input, in_slice);
375 Iterator out(output, out_slice);
376
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000377 f(in, out, in_slice, out_slice, *input->info(), 2, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100378 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000379 while(in_window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_3D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100380 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000381 static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100382 {
383 // Set in/out window
384 Window in_window(window);
385 Window out_window(window);
386
387 in_window.set(3, Window::Dimension(0, 1, 1));
388 out_window.set(3, Window::Dimension(0, 1, 1));
389
390 // Get first input and output slices
391 Window in_slice = in_window.first_slice_window_4D();
392 Window out_slice = out_window.first_slice_window_4D();
393
394 do
395 {
396 Iterator in(input, in_slice);
397 Iterator out(output, out_slice);
398
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000399 f(in, out, in_slice, out_slice, *input->info(), 3, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100400 }
401 while(in_window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_4D(out_slice));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100402 }
403};
404
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000405template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100406struct RedOpX
Georgios Pinitasd9769582017-08-03 10:19:40 +0100407{
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100408 /** NEON vector tag type. */
409 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
410
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000411 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100412 {
413 ARM_COMPUTE_UNUSED(out_slice);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000414 auto init_res_value = static_cast<T>(0.f);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100415 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000416 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100417 case ReductionOperation::ARG_IDX_MAX:
418 case ReductionOperation::ARG_IDX_MIN:
419 case ReductionOperation::MIN:
420 {
421 init_res_value = *reinterpret_cast<T *>(input.ptr());
422 break;
423 }
424 case ReductionOperation::PROD:
425 {
426 init_res_value = static_cast<T>(1.f);
427 break;
428 }
429 default:
430 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000431 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000432 auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000433 uint32x4x4_t vec_res_idx{ { 0 } };
Georgios Pinitasd9769582017-08-03 10:19:40 +0100434
435 execute_window_loop(in_slice, [&](const Coordinates & id)
436 {
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100437 const auto in_ptr = reinterpret_cast<const T *>(input.ptr());
438 const auto vec_elements = wrapper::vloadq(in_ptr);
439
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000440 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100441 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000442 case ReductionOperation::SUM_SQUARE:
443 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
444 break;
445 case ReductionOperation::MEAN_SUM:
446 case ReductionOperation::SUM:
447 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
448 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000449 case ReductionOperation::PROD:
450 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
451 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000452 case ReductionOperation::ARG_IDX_MIN:
453 {
454 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100455 vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000456 vec_res_value = temp_vec_res_value;
457 break;
458 }
459 case ReductionOperation::ARG_IDX_MAX:
460 {
461 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100462 vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000463 vec_res_value = temp_vec_res_value;
464 break;
465 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100466 case ReductionOperation::MIN:
467 {
468 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
469 break;
470 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000471 default:
472 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100473 }
Georgios Pinitasd9769582017-08-03 10:19:40 +0100474 },
475 input);
476
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000477 switch(op)
Michele Di Giorgio1c948d42018-11-20 16:03:01 +0000478 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000479 case ReductionOperation::SUM:
480 case ReductionOperation::SUM_SQUARE:
481 case ReductionOperation::MEAN_SUM:
482 {
483 auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
484 for(int i = 0; i < S / 4; ++i)
485 {
486 carry_res = wrapper::vpadd(carry_res, carry_res);
487 }
488 auto res = wrapper::vgetlane(carry_res, 0);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100489
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000490 if(op == ReductionOperation::MEAN_SUM)
491 {
492 res /= in_info.dimension(0);
493 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100494
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000495 *(reinterpret_cast<T *>(output.ptr())) = res;
496 break;
497 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000498 case ReductionOperation::PROD:
499 {
500 auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
501 T res = 1;
502 for(int i = 0; i < S / 2; ++i)
503 {
504 res *= wrapper::vgetlane(carry_res, i);
505 }
506 *(reinterpret_cast<T *>(output.ptr())) = res;
507 break;
508 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000509 case ReductionOperation::ARG_IDX_MIN:
510 case ReductionOperation::ARG_IDX_MAX:
511 {
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100512 auto res = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000513 *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
514 break;
515 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100516 case ReductionOperation::MIN:
517 {
518 *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0);
519 break;
520 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000521 default:
522 ARM_COMPUTE_ERROR("Not supported");
523 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100524 }
525};
526
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100527struct RedOpX_qasymm8
528{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000529 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100530 {
531 ARM_COMPUTE_UNUSED(out_slice);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000532 auto vec_res_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
533 auto vec_res_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
534 auto vec_res_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
535 auto vec_res_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100536
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000537 auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
538 auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
539 auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
540 auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));
541
Michalis Spyrou728d6f72019-01-16 13:57:58 +0000542 uint8x16_t vec_res_value = { 0 };
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000543
Usama Arifa4a08ad2019-05-20 12:38:33 +0100544 if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000545 {
546 vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{});
547 }
548
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000549 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100550 execute_window_loop(in_slice, [&](const Coordinates & id)
551 {
552 const auto vec_elements = wrapper::vloadq(input.ptr());
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000553 switch(op)
554 {
555 case ReductionOperation::SUM:
556 case ReductionOperation::MEAN_SUM:
557 {
558 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
559 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100560
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000561 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
562 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
563 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
564 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100565
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000566 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
567 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
568 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
569 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
570 break;
571 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000572 case ReductionOperation::PROD:
573 {
574 const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
575 const auto scale32x4f_4 = vdupq_n_f32(in_info.quantization_info().scale);
576
577 const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
578 const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
579
580 const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
581 const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
582 const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
583 const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
584
585 auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
586 auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
587 auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
588 auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
589
590 //de-quantize vec_elements
591 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
592 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
593 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
594 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
595
596 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
597 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
598 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
599 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
600 break;
601 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000602 case ReductionOperation::ARG_IDX_MIN:
603 {
604 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
605 vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
606 vec_res_value = temp_vec_res_value;
607 break;
608 }
609 case ReductionOperation::ARG_IDX_MAX:
610 {
611 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
612 vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
613 vec_res_value = temp_vec_res_value;
614 break;
615 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100616 case ReductionOperation::MIN:
617 {
618 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
619 break;
620 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000621 default:
622 ARM_COMPUTE_ERROR("Not supported");
623 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100624 },
625 input);
626
Usama Arifa4a08ad2019-05-20 12:38:33 +0100627 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100628 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100629 case ReductionOperation::ARG_IDX_MIN:
630 case ReductionOperation::ARG_IDX_MAX:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000631 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100632 auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
633 *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
634 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000635 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100636 case ReductionOperation::MIN:
637 {
638 *(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
639 break;
640 }
641 case ReductionOperation::PROD:
642 {
643 auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
644 carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
645 carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000646
Usama Arifa4a08ad2019-05-20 12:38:33 +0100647 float res = wrapper::vgetlane(carry_res, 0);
648 res *= wrapper::vgetlane(carry_res, 1);
649 res *= wrapper::vgetlane(carry_res, 2);
650 res *= wrapper::vgetlane(carry_res, 3);
651
652 //re-quantize result
653 res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset);
654 *(output.ptr()) = static_cast<uint8_t>(res);
655 break;
656 }
657 default:
658 {
659 auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
660 carry_res = wrapper::vadd(carry_res, vec_res_value3);
661 carry_res = wrapper::vadd(carry_res, vec_res_value4);
662
663 auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
664 carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
665 auto res = wrapper::vgetlane(carry_paddition, 0);
666
667 if(op == ReductionOperation::MEAN_SUM)
668 {
669 res /= in_info.dimension(0);
670 }
671
672 *(output.ptr()) = static_cast<uint8_t>(res);
673 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000674 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100675 }
676};
677
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000678template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100679struct RedOpYZW
680{
681 /** NEON vector tag type. */
682 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000683 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100684
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000685 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100686 {
687 ARM_COMPUTE_UNUSED(out_slice);
688
giuros01154bc1c2019-03-26 17:44:40 +0000689 execute_window_loop(in_slice, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100690 {
Michalis Spyrou728d6f72019-01-16 13:57:58 +0000691 neon_vector vec_res_value = { 0 };
Usama Arifa4a08ad2019-05-20 12:38:33 +0100692 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000693 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100694 case ReductionOperation::ARG_IDX_MAX:
695 case ReductionOperation::ARG_IDX_MIN:
696 case ReductionOperation::MIN:
697 {
698 vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
699 break;
700 }
701 case ReductionOperation::PROD:
702 {
703 vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
704 break;
705 }
706 default:
707 {
708 vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
709 break;
710 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000711 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000712 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000713
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100714 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
715 {
716 T *in_ptr;
717 switch(axis)
718 {
719 case 1:
720 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim)));
721 break;
722 case 2:
723 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim)));
724 break;
725 case 3:
726 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim)));
727 break;
728 default:
729 ARM_COMPUTE_ERROR("Not supported");
730 }
731 const auto vec_elements = wrapper::vloadq(in_ptr);
732
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000733 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100734 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000735 case ReductionOperation::SUM:
736 case ReductionOperation::MEAN_SUM:
737 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
738 break;
739 case ReductionOperation::SUM_SQUARE:
740 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
741 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000742 case ReductionOperation::PROD:
743 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
744 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000745 case ReductionOperation::ARG_IDX_MIN:
746 {
747 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
748 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
749 vec_res_value = temp_vec_res_value;
750 break;
751 }
752 case ReductionOperation::ARG_IDX_MAX:
753 {
754 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
755 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
756 vec_res_value = temp_vec_res_value;
757 break;
758 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100759 case ReductionOperation::MIN:
760 {
761 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
762 break;
763 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000764 default:
765 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100766 }
767 }
768
769 if(op == ReductionOperation::MEAN_SUM)
770 {
771 auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000772 vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100773 }
774
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000775 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
776 {
777 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
778 }
779 else
780 {
781 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value);
782 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100783 },
784 input, output);
785 }
786};
787
giuros01154bc1c2019-03-26 17:44:40 +0000788template <typename T, int S, int axis, ReductionOperation op>
789struct RedOpYZW_complex
790{
791 /** NEON vector tag type. */
792 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
793 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
794
795 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int, const ReductionOperation)
796 {
797 ARM_COMPUTE_UNUSED(out_slice);
798 ARM_COMPUTE_ERROR_ON(axis != 2);
799
800 const size_t stride_z = in_info.strides_in_bytes()[axis];
801
802 execute_window_loop(in_slice, [&](const Coordinates &)
803 {
804 neon_vector vec_res_value_0 = { 0 };
805 neon_vector vec_res_value_1 = { 0 };
806
807 vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
808 vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
809
810 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
811 {
812 T *in_ptr_0;
813 T *in_ptr_1;
814 switch(axis)
815 {
816 case 2:
817 in_ptr_0 = reinterpret_cast<T *>(input.ptr() + stride_z * dim);
818 in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 16 + stride_z * dim);
819 break;
820 default:
821 ARM_COMPUTE_ERROR("Not supported");
822 }
823 const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
824 const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);
825
826 switch(op)
827 {
828 case ReductionOperation::SUM:
829 vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
830 vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
831 break;
832 default:
833 ARM_COMPUTE_ERROR("Not supported");
834 }
835 }
836
837 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value_0);
838 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + 16), vec_res_value_1);
839
840 },
841 input, output);
842 }
843};
844
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100845struct RedOpYZW_qasymm8
846{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000847 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100848 {
849 ARM_COMPUTE_UNUSED(out_slice);
850
giuros01154bc1c2019-03-26 17:44:40 +0000851 execute_window_loop(in_slice, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100852 {
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000853 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000854 auto vec_res_value1 = vdupq_n_u32(0);
855 auto vec_res_value2 = vdupq_n_u32(0);
856 auto vec_res_value3 = vdupq_n_u32(0);
857 auto vec_res_value4 = vdupq_n_u32(0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000858
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000859 auto vec_res_value1_f = vdupq_n_f32(1);
860 auto vec_res_value2_f = vdupq_n_f32(1);
861 auto vec_res_value3_f = vdupq_n_f32(1);
862 auto vec_res_value4_f = vdupq_n_f32(1);
863
864 auto vec_res_value = wrapper::vloadq(input.ptr());
865
866 for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100867 {
868 uint8_t *in_ptr;
869 switch(axis)
870 {
871 case 1:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000872 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100873 break;
874 case 2:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000875 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100876 break;
877 case 3:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000878 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100879 break;
880 default:
881 ARM_COMPUTE_ERROR("Not supported");
882 }
883 const auto vec_elements = wrapper::vloadq(in_ptr);
884
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000885 switch(op)
886 {
887 case ReductionOperation::SUM:
888 case ReductionOperation::MEAN_SUM:
889 {
890 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
891 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100892
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000893 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
894 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
895 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
896 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100897
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000898 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
899 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
900 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
901 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
902 break;
903 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000904 case ReductionOperation::PROD:
905 {
906 const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
907 const auto scale32x4f_4 = vdupq_n_f32(in_info.quantization_info().scale);
908
909 const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
910 const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
911
912 const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
913 const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
914 const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
915 const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
916
917 auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
918 auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
919 auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
920 auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
921
922 //de-quantize vec_elements
923 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
924 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
925 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
926 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
927
928 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
929 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
930 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
931 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
932 break;
933 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000934 case ReductionOperation::ARG_IDX_MIN:
935 {
936 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000937 vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000938 vec_res_value = temp_vec_res_value;
939 break;
940 }
941 case ReductionOperation::ARG_IDX_MAX:
942 {
943 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000944 vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000945 vec_res_value = temp_vec_res_value;
946 break;
947 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100948 case ReductionOperation::MIN:
949 {
950 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
951 break;
952 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000953 default:
954 ARM_COMPUTE_ERROR("Not supported");
955 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100956 }
957
958 if(op == ReductionOperation::MEAN_SUM)
959 {
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000960 const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
961 vec_res_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv);
962 vec_res_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv);
963 vec_res_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv);
964 vec_res_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100965
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000966 vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
967 vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
968 vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
969 vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
970 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000971 else if(op == ReductionOperation::PROD)
972 {
973 const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
974 const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(in_info.quantization_info().scale));
975
976 //re-quantize
977 vec_res_value1_f = vaddq_f32(vmulq_f32(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
978 vec_res_value2_f = vaddq_f32(vmulq_f32(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
979 vec_res_value3_f = vaddq_f32(vmulq_f32(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
980 vec_res_value4_f = vaddq_f32(vmulq_f32(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
981
982 vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
983 vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
984 vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
985 vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
986 }
987
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000988 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
989 {
990 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
991 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vec_res_idx.val[1]);
992 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vec_res_idx.val[2]);
993 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vec_res_idx.val[3]);
994 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100995 else if(op == ReductionOperation::ARG_IDX_MIN)
996 {
997 wrapper::vstore(output.ptr(), vec_res_value);
998 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000999 else
1000 {
1001 const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1002 const auto temp16x8t_2 = vcombine_u16(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1003 auto res = vcombine_u8(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
1004 wrapper::vstore(output.ptr(), res);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001005 }
1006
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001007 },
1008 input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001009 }
1010};
1011
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001012void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001013{
giuros01154bc1c2019-03-26 17:44:40 +00001014 const bool is_complex = (input->info()->num_channels() == 2);
1015
1016 if(is_complex)
1017 {
1018 switch(axis)
1019 {
1020 case 2:
1021 switch(input->info()->data_type())
1022 {
1023 case DataType::F32:
1024 switch(op)
1025 {
1026 case ReductionOperation::SUM:
1027 return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op);
1028 default:
1029 ARM_COMPUTE_ERROR("Not supported");
1030 }
1031 default:
1032 ARM_COMPUTE_ERROR("Not supported");
1033 }
1034 default:
1035 ARM_COMPUTE_ERROR("Not supported");
1036 }
1037 }
1038
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001039 switch(axis)
1040 {
1041 case 0:
1042 switch(input->info()->data_type())
1043 {
1044 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001045 return Reducer<RedOpX_qasymm8>::reduceX(window, input, output, RedOpX_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001046#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1047 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001048 return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001049#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1050 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001051 return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001052 case DataType::S32:
1053 return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001054 default:
1055 ARM_COMPUTE_ERROR("Not supported");
1056 }
1057 case 1:
1058 switch(input->info()->data_type())
1059 {
1060 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001061 return Reducer<RedOpYZW_qasymm8>::reduceY(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001062#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1063 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001064 return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001065#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1066 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001067 return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001068 case DataType::S32:
1069 return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001070 default:
1071 ARM_COMPUTE_ERROR("Not supported");
1072 }
1073 case 2:
1074 switch(input->info()->data_type())
1075 {
1076 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001077 return Reducer<RedOpYZW_qasymm8>::reduceZ(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001078#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1079 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001080 return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001081#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1082 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001083 return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001084 case DataType::S32:
1085 return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001086 default:
1087 ARM_COMPUTE_ERROR("Not supported");
1088 }
1089 case 3:
1090 switch(input->info()->data_type())
1091 {
1092 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001093 return Reducer<RedOpYZW_qasymm8>::reduceW(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001094#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1095 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001096 return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001097#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1098 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001099 return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001100 case DataType::S32:
1101 return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001102 default:
1103 ARM_COMPUTE_ERROR("Not supported");
1104 }
1105 default:
1106 ARM_COMPUTE_ERROR("Unsupported reduction axis");
1107 }
1108}
John Richardson73d4aef2018-05-08 14:34:33 +01001109
1110Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1111{
1112 ARM_COMPUTE_UNUSED(op);
1113
1114 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
Georgios Pinitas8f5802f2019-02-22 11:08:32 +00001115 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
giuros01154bc1c2019-03-26 17:44:40 +00001116
1117 if(input->num_channels() == 1)
1118 {
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001119 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
giuros01154bc1c2019-03-26 17:44:40 +00001120 }
1121 else
1122 {
1123 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
1124 ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM);
1125 ARM_COMPUTE_RETURN_ERROR_ON(axis != 2);
1126 }
John Richardson73d4aef2018-05-08 14:34:33 +01001127
1128 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001129 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
John Richardson73d4aef2018-05-08 14:34:33 +01001130
1131 if(output->total_size() != 0)
1132 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001133 bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
1134 if(!is_arg_min_max)
1135 {
1136 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Isabella Gottardi0a1090a2019-02-14 18:07:36 +00001137 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
giuros01154bc1c2019-03-26 17:44:40 +00001138 ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels());
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001139 }
1140 else
1141 {
1142 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32);
1143 }
John Richardson73d4aef2018-05-08 14:34:33 +01001144
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001145 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001146 const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
1147 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
1148 }
1149
1150 return Status{};
1151}
1152
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001153std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op)
John Richardson73d4aef2018-05-08 14:34:33 +01001154{
1155 // Calculate output shape and set if empty
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001156 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001157
1158 // Output auto initialization if not yet initialized
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001159 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
1160 DataType output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
giuros01154bc1c2019-03-26 17:44:40 +00001161 auto_init_if_empty(*output, input->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
John Richardson73d4aef2018-05-08 14:34:33 +01001162
1163 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
1164
1165 // Configure kernel window
1166 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
1167 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
1168 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
1169
1170 bool window_changed = update_window_and_padding(win, input_access, output_access);
1171 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
1172
1173 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
1174
1175 return std::make_tuple(err, win);
1176}
Georgios Pinitasd9769582017-08-03 10:19:40 +01001177} // namespace
1178
1179NEReductionOperationKernel::NEReductionOperationKernel()
1180 : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE), _border_size()
1181{
1182}
1183
1184BorderSize NEReductionOperationKernel::border_size() const
1185{
1186 return _border_size;
1187}
1188
1189void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
1190{
1191 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001192
John Richardson73d4aef2018-05-08 14:34:33 +01001193 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001194
1195 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
1196
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001197 _input = input;
1198 _output = output;
1199 _border_size = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
1200 _op = op;
1201 _reduction_axis = axis;
Georgios Pinitasd9769582017-08-03 10:19:40 +01001202
1203 // Configure kernel window
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001204 auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis, op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001205
John Richardson73d4aef2018-05-08 14:34:33 +01001206 ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001207
John Richardson73d4aef2018-05-08 14:34:33 +01001208 INEKernel::configure(std::get<1>(win_config));
1209}
1210
1211Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1212{
1213 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001214 ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis, op)));
John Richardson73d4aef2018-05-08 14:34:33 +01001215
1216 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +01001217}
1218
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001219void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001220{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001221 ARM_COMPUTE_UNUSED(info);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001222 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1223 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1224
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001225 reduce_op(window, _input, _output, _reduction_axis, _op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001226}
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001227} // namespace arm_compute