blob: e6edf22083a5dac3c7e40c151e606cb39c49cd04 [file] [log] [blame]
/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/core/NEON/kernels/NEReductionOperationKernel.h"
25
Georgios Pinitas8f5802f2019-02-22 11:08:32 +000026#include "arm_compute/core/CPP/Validate.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010027#include "arm_compute/core/Coordinates.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/IAccessWindow.h"
30#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/INEKernel.h"
32#include "arm_compute/core/NEON/NEMath.h"
John Richardson73d4aef2018-05-08 14:34:33 +010033#include "arm_compute/core/TensorInfo.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010034#include "arm_compute/core/Validate.h"
Michalis Spyrouaea14c62019-01-03 11:10:25 +000035#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010036
Michalis Spyroubcf8a962018-10-12 10:51:31 +010037#include "arm_compute/core/NEON/wrapper/wrapper.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010038#include <arm_neon.h>
39
Michalis Spyroubcf8a962018-10-12 10:51:31 +010040namespace arm_compute
41{
Georgios Pinitasd9769582017-08-03 10:19:40 +010042namespace
43{
Michalis Spyroub9626ab2019-05-13 17:41:01 +010044template <typename T>
45uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000046{
47 uint32x4_t mask{ 0 };
48 if(op == ReductionOperation::ARG_IDX_MIN)
49 {
50 mask = wrapper::vcgt(b, a);
51 }
52 else
53 {
54 mask = wrapper::vclt(b, a);
55 }
56
57 uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
58 if(axis != 0)
59 {
60 vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
61 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +000062 uint32x4x4_t res = { { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000063
64 return res;
65}
66
67uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x4_t c, ReductionOperation op, int axis)
68{
Georgios Pinitasd57891a2019-02-19 18:10:03 +000069 uint32x4x4_t mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000070 uint8x16_t mask_u8{ 0 };
71 if(op == ReductionOperation::ARG_IDX_MIN)
72 {
73 mask_u8 = wrapper::vcgt(b, a);
74 }
75 else
76 {
77 mask_u8 = wrapper::vclt(b, a);
78 }
Michalis Spyrou254a48a2019-01-14 17:27:39 +000079 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
80 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
81 mask.val[0] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
82 mask.val[1] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
83 mask.val[2] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
84 mask.val[3] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
85
Michalis Spyrouaea14c62019-01-03 11:10:25 +000086 uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
87 { idx + 4, idx + 5, idx + 6, idx + 7 },
88 { idx + 8, idx + 9, idx + 10, idx + 11 },
89 { idx + 12, idx + 13, idx + 14, idx + 15 }
90 }
91 };
92 if(axis != 0)
93 {
94 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
95 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
96 vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
97 vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
98 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +000099 uint32x4x4_t res =
100 {
101 {
102 vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
103 vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
104 vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
105 vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
106 }
107 };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000108
109 return res;
110}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100111
112// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
113float32x2_t calculate_min(float32x4_t in)
114{
115 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
116 return wrapper::vpmin(pmin, pmin);
117}
118
119// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
120float32x2_t calculate_max(float32x4_t in)
121{
122 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
123 return wrapper::vpmax(pmax, pmax);
124}
125// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
126int32x2_t calculate_min(int32x4_t in)
127{
128 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
129 return wrapper::vpmin(pmin, pmin);
130}
131
132// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
133int32x2_t calculate_max(int32x4_t in)
134{
135 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
136 return wrapper::vpmax(pmax, pmax);
137}
138
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100139template <typename T>
140uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000141{
142 uint32x4_t res_idx_mask{ 0 };
143 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
144
145 if(op == ReductionOperation::ARG_IDX_MIN)
146 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100147 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000148 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
149 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
150 }
151 else
152 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100153 auto pmax = calculate_max(vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100154 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000155 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
156 }
157
158 res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
159 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
160 pmin = wrapper::vpmin(pmin, pmin);
161 uint32_t res = wrapper::vgetlane(pmin, 0);
162
163 return (res - 0xFFFFFFFF);
164}
165
Usama Arifa4a08ad2019-05-20 12:38:33 +0100166// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
167inline uint8x8_t calculate_min(uint8x16_t in)
168{
169 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
170 pmin = wrapper::vpmin(pmin, pmin);
171 pmin = wrapper::vpmin(pmin, pmin);
172 return wrapper::vpmin(pmin, pmin);
173}
174// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
175inline uint8x8_t calculate_max(uint8x16_t in)
176{
177 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
178 pmax = wrapper::vpmax(pmax, pmax);
179 pmax = wrapper::vpmax(pmax, pmax);
180 return wrapper::vpmax(pmax, pmax);
181}
182
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000183uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_value, ReductionOperation op)
184{
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000185 uint32x4x4_t res_idx_mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000186 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
187 uint8x16_t mask_u8{ 0 };
188 if(op == ReductionOperation::ARG_IDX_MIN)
189 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100190 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000191 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
192 }
193 else
194 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100195 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000196 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
197 }
198
199 // Widen vectors
200 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
201 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
202 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
203 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
204 auto wide_u32_3 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
205 auto wide_u32_4 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
206 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
207 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
208 res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
209 res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
210 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
211 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
212 res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
213 res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);
214
215 uint32_t res = 0xFFFFFFFF;
216 int iter = 0;
217 do
218 {
219 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
220 pmin = wrapper::vpmin(pmin, pmin);
221 res = std::min(wrapper::vgetlane(pmin, 0), res);
222 iter++;
223 }
224 while(iter < 4);
225
226 return (res - 0xFFFFFFFF);
227}
228#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
229uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
230{
231 uint32x4x2_t mask{ 0 };
232 uint16x8_t mask_u16{ 0 };
233 if(op == ReductionOperation::ARG_IDX_MIN)
234 {
235 mask_u16 = wrapper::vcgt(b, a);
236 }
237 else
238 {
239 mask_u16 = wrapper::vclt(b, a);
240 }
241 mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
242 mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
243 uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
244 { idx + 4, idx + 5, idx + 6, idx + 7 }
245 }
246 };
247 if(axis != 0)
248 {
249 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
250 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
251 }
252 uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
253 wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
254 0, 0
255 };
256
257 return res;
258}
259
Usama Arifa4a08ad2019-05-20 12:38:33 +0100260// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
261inline float16x4_t calculate_min(float16x8_t in)
262{
263 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
264 pmin = wrapper::vpmin(pmin, pmin);
265 return wrapper::vpmin(pmin, pmin);
266}
267// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
268inline float16x4_t calculate_max(float16x8_t in)
269{
270 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
271 pmax = wrapper::vpmax(pmax, pmax);
272 return wrapper::vpmax(pmax, pmax);
273}
274
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000275uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
276{
277 uint32x4x2_t res_idx_mask{ 0 };
278 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
279 uint16x8_t mask_u16;
280 if(op == ReductionOperation::ARG_IDX_MIN)
281 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100282 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000283 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
284 }
285 else
286 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100287 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000288 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
289 }
290
291 // Widen vectors
292 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
293 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
294 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
295 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
296 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
297 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
298
299 uint32_t res = 0xFFFFFFFF;
300 int iter = 0;
301 do
302 {
303 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
304 pmin = wrapper::vpmin(pmin, pmin);
305 res = std::min(wrapper::vgetlane(pmin, 0), res);
306 iter++;
307 }
308 while(iter < 2);
309
310 return (res - 0xFFFFFFFF);
311}
312#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
313
Georgios Pinitasd9769582017-08-03 10:19:40 +0100314template <class F>
315class Reducer
316{
317public:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000318 static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100319 {
320 // Set out window
321 Window out_window(window);
322 out_window.set(Window::DimX, Window::Dimension(0, 0, 0));
323
324 // Get first input and output slices
325 Window in_slice = window.first_slice_window_1D();
326 Window out_slice = out_window.first_slice_window_1D();
327
328 do
329 {
330 Iterator in(input, in_slice);
331 Iterator out(output, out_slice);
332
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000333 f(in, out, in_slice, out_slice, *input->info(), op);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100334 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100335 while(window.slide_window_slice_1D(in_slice) && out_window.slide_window_slice_1D(out_slice));
336 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000337 static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100338 {
339 // Set in window
340 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000341 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100342
343 in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000344 out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100345
346 // Get first input and output slices
347 Window in_slice = in_window.first_slice_window_2D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000348 Window out_slice = out_window.first_slice_window_2D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100349
350 do
351 {
352 Iterator in(input, in_slice);
353 Iterator out(output, out_slice);
354
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000355 f(in, out, in_slice, out_slice, *input->info(), 1, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100356 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000357 while(in_window.slide_window_slice_2D(in_slice) && out_window.slide_window_slice_2D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100358 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000359 static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100360 {
361 // Set in window
362 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000363 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100364
365 in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000366 out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100367
368 // Get first input and output slices
369 Window in_slice = in_window.first_slice_window_3D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000370 Window out_slice = out_window.first_slice_window_3D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100371
372 do
373 {
374 Iterator in(input, in_slice);
375 Iterator out(output, out_slice);
376
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000377 f(in, out, in_slice, out_slice, *input->info(), 2, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100378 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000379 while(in_window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_3D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100380 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000381 static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100382 {
383 // Set in/out window
384 Window in_window(window);
385 Window out_window(window);
386
387 in_window.set(3, Window::Dimension(0, 1, 1));
388 out_window.set(3, Window::Dimension(0, 1, 1));
389
390 // Get first input and output slices
391 Window in_slice = in_window.first_slice_window_4D();
392 Window out_slice = out_window.first_slice_window_4D();
393
394 do
395 {
396 Iterator in(input, in_slice);
397 Iterator out(output, out_slice);
398
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000399 f(in, out, in_slice, out_slice, *input->info(), 3, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100400 }
401 while(in_window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_4D(out_slice));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100402 }
403};
404
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000405template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100406struct RedOpX
Georgios Pinitasd9769582017-08-03 10:19:40 +0100407{
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100408 /** NEON vector tag type. */
409 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
410
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000411 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100412 {
413 ARM_COMPUTE_UNUSED(out_slice);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000414 auto init_res_value = static_cast<T>(0.f);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100415 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000416 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100417 case ReductionOperation::ARG_IDX_MAX:
418 case ReductionOperation::ARG_IDX_MIN:
419 case ReductionOperation::MIN:
Usama Arif28f0dd92019-05-20 13:44:34 +0100420 case ReductionOperation::MAX:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100421 {
422 init_res_value = *reinterpret_cast<T *>(input.ptr());
423 break;
424 }
425 case ReductionOperation::PROD:
426 {
427 init_res_value = static_cast<T>(1.f);
428 break;
429 }
430 default:
431 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000432 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000433 auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000434 uint32x4x4_t vec_res_idx{ { 0 } };
Georgios Pinitasd9769582017-08-03 10:19:40 +0100435
436 execute_window_loop(in_slice, [&](const Coordinates & id)
437 {
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100438 const auto in_ptr = reinterpret_cast<const T *>(input.ptr());
439 const auto vec_elements = wrapper::vloadq(in_ptr);
440
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000441 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100442 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000443 case ReductionOperation::SUM_SQUARE:
444 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
445 break;
446 case ReductionOperation::MEAN_SUM:
447 case ReductionOperation::SUM:
448 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
449 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000450 case ReductionOperation::PROD:
451 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
452 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000453 case ReductionOperation::ARG_IDX_MIN:
454 {
455 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100456 vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000457 vec_res_value = temp_vec_res_value;
458 break;
459 }
460 case ReductionOperation::ARG_IDX_MAX:
461 {
462 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100463 vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000464 vec_res_value = temp_vec_res_value;
465 break;
466 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100467 case ReductionOperation::MIN:
468 {
469 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
470 break;
471 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100472 case ReductionOperation::MAX:
473 {
474 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
475 break;
476 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000477 default:
478 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100479 }
Georgios Pinitasd9769582017-08-03 10:19:40 +0100480 },
481 input);
482
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000483 switch(op)
Michele Di Giorgio1c948d42018-11-20 16:03:01 +0000484 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000485 case ReductionOperation::SUM:
486 case ReductionOperation::SUM_SQUARE:
487 case ReductionOperation::MEAN_SUM:
488 {
489 auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
490 for(int i = 0; i < S / 4; ++i)
491 {
492 carry_res = wrapper::vpadd(carry_res, carry_res);
493 }
494 auto res = wrapper::vgetlane(carry_res, 0);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100495
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000496 if(op == ReductionOperation::MEAN_SUM)
497 {
498 res /= in_info.dimension(0);
499 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100500
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000501 *(reinterpret_cast<T *>(output.ptr())) = res;
502 break;
503 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000504 case ReductionOperation::PROD:
505 {
506 auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
507 T res = 1;
508 for(int i = 0; i < S / 2; ++i)
509 {
510 res *= wrapper::vgetlane(carry_res, i);
511 }
512 *(reinterpret_cast<T *>(output.ptr())) = res;
513 break;
514 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000515 case ReductionOperation::ARG_IDX_MIN:
516 case ReductionOperation::ARG_IDX_MAX:
517 {
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100518 auto res = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000519 *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
520 break;
521 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100522 case ReductionOperation::MIN:
523 {
524 *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0);
525 break;
526 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100527 case ReductionOperation::MAX:
528 {
529 *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_max(vec_res_value), 0);
530 break;
531 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000532 default:
533 ARM_COMPUTE_ERROR("Not supported");
534 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100535 }
536};
537
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100538struct RedOpX_qasymm8
539{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000540 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100541 {
542 ARM_COMPUTE_UNUSED(out_slice);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000543 auto vec_res_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
544 auto vec_res_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
545 auto vec_res_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
546 auto vec_res_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100547
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000548 auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
549 auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
550 auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
551 auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));
552
Michalis Spyrou728d6f72019-01-16 13:57:58 +0000553 uint8x16_t vec_res_value = { 0 };
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000554
Usama Arif28f0dd92019-05-20 13:44:34 +0100555 if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000556 {
557 vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{});
558 }
559
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000560 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100561 execute_window_loop(in_slice, [&](const Coordinates & id)
562 {
563 const auto vec_elements = wrapper::vloadq(input.ptr());
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000564 switch(op)
565 {
566 case ReductionOperation::SUM:
567 case ReductionOperation::MEAN_SUM:
568 {
569 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
570 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100571
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000572 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
573 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
574 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
575 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100576
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000577 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
578 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
579 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
580 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
581 break;
582 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000583 case ReductionOperation::PROD:
584 {
585 const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
586 const auto scale32x4f_4 = vdupq_n_f32(in_info.quantization_info().scale);
587
588 const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
589 const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
590
591 const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
592 const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
593 const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
594 const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
595
596 auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
597 auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
598 auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
599 auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
600
601 //de-quantize vec_elements
602 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
603 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
604 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
605 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
606
607 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
608 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
609 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
610 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
611 break;
612 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000613 case ReductionOperation::ARG_IDX_MIN:
614 {
615 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
616 vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
617 vec_res_value = temp_vec_res_value;
618 break;
619 }
620 case ReductionOperation::ARG_IDX_MAX:
621 {
622 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
623 vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
624 vec_res_value = temp_vec_res_value;
625 break;
626 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100627 case ReductionOperation::MIN:
628 {
629 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
630 break;
631 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100632 case ReductionOperation::MAX:
633 {
634 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
635 break;
636 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000637 default:
638 ARM_COMPUTE_ERROR("Not supported");
639 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100640 },
641 input);
642
Usama Arifa4a08ad2019-05-20 12:38:33 +0100643 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100644 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100645 case ReductionOperation::ARG_IDX_MIN:
646 case ReductionOperation::ARG_IDX_MAX:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000647 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100648 auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
649 *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
650 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000651 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100652 case ReductionOperation::MIN:
653 {
654 *(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
655 break;
656 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100657 case ReductionOperation::MAX:
658 {
659 *(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
660 break;
661 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100662 case ReductionOperation::PROD:
663 {
664 auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
665 carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
666 carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000667
Usama Arifa4a08ad2019-05-20 12:38:33 +0100668 float res = wrapper::vgetlane(carry_res, 0);
669 res *= wrapper::vgetlane(carry_res, 1);
670 res *= wrapper::vgetlane(carry_res, 2);
671 res *= wrapper::vgetlane(carry_res, 3);
672
673 //re-quantize result
674 res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset);
675 *(output.ptr()) = static_cast<uint8_t>(res);
676 break;
677 }
678 default:
679 {
680 auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
681 carry_res = wrapper::vadd(carry_res, vec_res_value3);
682 carry_res = wrapper::vadd(carry_res, vec_res_value4);
683
684 auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
685 carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
686 auto res = wrapper::vgetlane(carry_paddition, 0);
687
688 if(op == ReductionOperation::MEAN_SUM)
689 {
690 res /= in_info.dimension(0);
691 }
692
693 *(output.ptr()) = static_cast<uint8_t>(res);
694 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000695 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100696 }
697};
698
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000699template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100700struct RedOpYZW
701{
702 /** NEON vector tag type. */
703 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000704 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100705
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000706 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100707 {
708 ARM_COMPUTE_UNUSED(out_slice);
709
giuros01154bc1c2019-03-26 17:44:40 +0000710 execute_window_loop(in_slice, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100711 {
Michalis Spyrou728d6f72019-01-16 13:57:58 +0000712 neon_vector vec_res_value = { 0 };
Usama Arifa4a08ad2019-05-20 12:38:33 +0100713 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000714 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100715 case ReductionOperation::ARG_IDX_MAX:
716 case ReductionOperation::ARG_IDX_MIN:
717 case ReductionOperation::MIN:
Usama Arif28f0dd92019-05-20 13:44:34 +0100718 case ReductionOperation::MAX:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100719 {
720 vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
721 break;
722 }
723 case ReductionOperation::PROD:
724 {
725 vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
726 break;
727 }
728 default:
729 {
730 vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
731 break;
732 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000733 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000734 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000735
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100736 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
737 {
738 T *in_ptr;
739 switch(axis)
740 {
741 case 1:
742 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim)));
743 break;
744 case 2:
745 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim)));
746 break;
747 case 3:
748 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim)));
749 break;
750 default:
751 ARM_COMPUTE_ERROR("Not supported");
752 }
753 const auto vec_elements = wrapper::vloadq(in_ptr);
754
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000755 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100756 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000757 case ReductionOperation::SUM:
758 case ReductionOperation::MEAN_SUM:
759 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
760 break;
761 case ReductionOperation::SUM_SQUARE:
762 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
763 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000764 case ReductionOperation::PROD:
765 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
766 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000767 case ReductionOperation::ARG_IDX_MIN:
768 {
769 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
770 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
771 vec_res_value = temp_vec_res_value;
772 break;
773 }
774 case ReductionOperation::ARG_IDX_MAX:
775 {
776 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
777 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
778 vec_res_value = temp_vec_res_value;
779 break;
780 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100781 case ReductionOperation::MIN:
782 {
783 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
784 break;
785 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100786 case ReductionOperation::MAX:
787 {
788 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
789 break;
790 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000791 default:
792 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100793 }
794 }
795
796 if(op == ReductionOperation::MEAN_SUM)
797 {
798 auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000799 vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100800 }
801
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000802 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
803 {
804 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
805 }
806 else
807 {
808 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value);
809 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100810 },
811 input, output);
812 }
813};
814
giuros01154bc1c2019-03-26 17:44:40 +0000815template <typename T, int S, int axis, ReductionOperation op>
816struct RedOpYZW_complex
817{
818 /** NEON vector tag type. */
819 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
820 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
821
822 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int, const ReductionOperation)
823 {
824 ARM_COMPUTE_UNUSED(out_slice);
825 ARM_COMPUTE_ERROR_ON(axis != 2);
826
827 const size_t stride_z = in_info.strides_in_bytes()[axis];
828
829 execute_window_loop(in_slice, [&](const Coordinates &)
830 {
831 neon_vector vec_res_value_0 = { 0 };
832 neon_vector vec_res_value_1 = { 0 };
833
834 vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
835 vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
836
837 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
838 {
839 T *in_ptr_0;
840 T *in_ptr_1;
841 switch(axis)
842 {
843 case 2:
844 in_ptr_0 = reinterpret_cast<T *>(input.ptr() + stride_z * dim);
845 in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 16 + stride_z * dim);
846 break;
847 default:
848 ARM_COMPUTE_ERROR("Not supported");
849 }
850 const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
851 const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);
852
853 switch(op)
854 {
855 case ReductionOperation::SUM:
856 vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
857 vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
858 break;
859 default:
860 ARM_COMPUTE_ERROR("Not supported");
861 }
862 }
863
864 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value_0);
865 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + 16), vec_res_value_1);
866
867 },
868 input, output);
869 }
870};
871
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100872struct RedOpYZW_qasymm8
873{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000874 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100875 {
876 ARM_COMPUTE_UNUSED(out_slice);
877
giuros01154bc1c2019-03-26 17:44:40 +0000878 execute_window_loop(in_slice, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100879 {
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000880 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000881 auto vec_res_value1 = vdupq_n_u32(0);
882 auto vec_res_value2 = vdupq_n_u32(0);
883 auto vec_res_value3 = vdupq_n_u32(0);
884 auto vec_res_value4 = vdupq_n_u32(0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000885
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000886 auto vec_res_value1_f = vdupq_n_f32(1);
887 auto vec_res_value2_f = vdupq_n_f32(1);
888 auto vec_res_value3_f = vdupq_n_f32(1);
889 auto vec_res_value4_f = vdupq_n_f32(1);
890
891 auto vec_res_value = wrapper::vloadq(input.ptr());
892
893 for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100894 {
895 uint8_t *in_ptr;
896 switch(axis)
897 {
898 case 1:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000899 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100900 break;
901 case 2:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000902 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100903 break;
904 case 3:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000905 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100906 break;
907 default:
908 ARM_COMPUTE_ERROR("Not supported");
909 }
910 const auto vec_elements = wrapper::vloadq(in_ptr);
911
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000912 switch(op)
913 {
914 case ReductionOperation::SUM:
915 case ReductionOperation::MEAN_SUM:
916 {
917 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
918 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100919
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000920 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
921 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
922 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
923 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100924
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000925 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
926 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
927 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
928 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
929 break;
930 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000931 case ReductionOperation::PROD:
932 {
933 const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
934 const auto scale32x4f_4 = vdupq_n_f32(in_info.quantization_info().scale);
935
936 const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
937 const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
938
939 const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
940 const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
941 const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
942 const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
943
944 auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
945 auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
946 auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
947 auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
948
949 //de-quantize vec_elements
950 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
951 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
952 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
953 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
954
955 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
956 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
957 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
958 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
959 break;
960 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000961 case ReductionOperation::ARG_IDX_MIN:
962 {
963 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000964 vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000965 vec_res_value = temp_vec_res_value;
966 break;
967 }
968 case ReductionOperation::ARG_IDX_MAX:
969 {
970 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000971 vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000972 vec_res_value = temp_vec_res_value;
973 break;
974 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100975 case ReductionOperation::MIN:
976 {
977 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
978 break;
979 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100980 case ReductionOperation::MAX:
981 {
982 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
983 break;
984 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000985 default:
986 ARM_COMPUTE_ERROR("Not supported");
987 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100988 }
989
990 if(op == ReductionOperation::MEAN_SUM)
991 {
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000992 const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
993 vec_res_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv);
994 vec_res_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv);
995 vec_res_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv);
996 vec_res_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100997
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000998 vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
999 vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
1000 vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
1001 vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
1002 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001003 else if(op == ReductionOperation::PROD)
1004 {
1005 const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
1006 const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(in_info.quantization_info().scale));
1007
1008 //re-quantize
1009 vec_res_value1_f = vaddq_f32(vmulq_f32(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
1010 vec_res_value2_f = vaddq_f32(vmulq_f32(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
1011 vec_res_value3_f = vaddq_f32(vmulq_f32(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
1012 vec_res_value4_f = vaddq_f32(vmulq_f32(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
1013
1014 vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
1015 vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
1016 vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
1017 vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
1018 }
1019
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001020 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
1021 {
1022 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
1023 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vec_res_idx.val[1]);
1024 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vec_res_idx.val[2]);
1025 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vec_res_idx.val[3]);
1026 }
Usama Arifa4a08ad2019-05-20 12:38:33 +01001027 else if(op == ReductionOperation::ARG_IDX_MIN)
1028 {
1029 wrapper::vstore(output.ptr(), vec_res_value);
1030 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001031 else
1032 {
1033 const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1034 const auto temp16x8t_2 = vcombine_u16(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1035 auto res = vcombine_u8(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
1036 wrapper::vstore(output.ptr(), res);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001037 }
1038
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001039 },
1040 input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001041 }
1042};
1043
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001044void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001045{
giuros01154bc1c2019-03-26 17:44:40 +00001046 const bool is_complex = (input->info()->num_channels() == 2);
1047
1048 if(is_complex)
1049 {
1050 switch(axis)
1051 {
1052 case 2:
1053 switch(input->info()->data_type())
1054 {
1055 case DataType::F32:
1056 switch(op)
1057 {
1058 case ReductionOperation::SUM:
1059 return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op);
1060 default:
1061 ARM_COMPUTE_ERROR("Not supported");
1062 }
1063 default:
1064 ARM_COMPUTE_ERROR("Not supported");
1065 }
1066 default:
1067 ARM_COMPUTE_ERROR("Not supported");
1068 }
1069 }
1070
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001071 switch(axis)
1072 {
1073 case 0:
1074 switch(input->info()->data_type())
1075 {
1076 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001077 return Reducer<RedOpX_qasymm8>::reduceX(window, input, output, RedOpX_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001078#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1079 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001080 return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001081#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1082 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001083 return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001084 case DataType::S32:
1085 return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001086 default:
1087 ARM_COMPUTE_ERROR("Not supported");
1088 }
1089 case 1:
1090 switch(input->info()->data_type())
1091 {
1092 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001093 return Reducer<RedOpYZW_qasymm8>::reduceY(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001094#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1095 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001096 return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001097#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1098 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001099 return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001100 case DataType::S32:
1101 return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001102 default:
1103 ARM_COMPUTE_ERROR("Not supported");
1104 }
1105 case 2:
1106 switch(input->info()->data_type())
1107 {
1108 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001109 return Reducer<RedOpYZW_qasymm8>::reduceZ(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001110#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1111 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001112 return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001113#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1114 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001115 return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001116 case DataType::S32:
1117 return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001118 default:
1119 ARM_COMPUTE_ERROR("Not supported");
1120 }
1121 case 3:
1122 switch(input->info()->data_type())
1123 {
1124 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001125 return Reducer<RedOpYZW_qasymm8>::reduceW(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001126#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1127 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001128 return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001129#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1130 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001131 return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001132 case DataType::S32:
1133 return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001134 default:
1135 ARM_COMPUTE_ERROR("Not supported");
1136 }
1137 default:
1138 ARM_COMPUTE_ERROR("Unsupported reduction axis");
1139 }
1140}
John Richardson73d4aef2018-05-08 14:34:33 +01001141
1142Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1143{
1144 ARM_COMPUTE_UNUSED(op);
1145
1146 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
Georgios Pinitas8f5802f2019-02-22 11:08:32 +00001147 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
giuros01154bc1c2019-03-26 17:44:40 +00001148
1149 if(input->num_channels() == 1)
1150 {
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001151 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
giuros01154bc1c2019-03-26 17:44:40 +00001152 }
1153 else
1154 {
1155 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
1156 ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM);
1157 ARM_COMPUTE_RETURN_ERROR_ON(axis != 2);
1158 }
John Richardson73d4aef2018-05-08 14:34:33 +01001159
1160 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001161 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
John Richardson73d4aef2018-05-08 14:34:33 +01001162
1163 if(output->total_size() != 0)
1164 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001165 bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
1166 if(!is_arg_min_max)
1167 {
1168 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Isabella Gottardi0a1090a2019-02-14 18:07:36 +00001169 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
giuros01154bc1c2019-03-26 17:44:40 +00001170 ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels());
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001171 }
1172 else
1173 {
1174 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32);
1175 }
John Richardson73d4aef2018-05-08 14:34:33 +01001176
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001177 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001178 const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
1179 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
1180 }
1181
1182 return Status{};
1183}
1184
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001185std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op)
John Richardson73d4aef2018-05-08 14:34:33 +01001186{
1187 // Calculate output shape and set if empty
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001188 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001189
1190 // Output auto initialization if not yet initialized
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001191 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
1192 DataType output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
giuros01154bc1c2019-03-26 17:44:40 +00001193 auto_init_if_empty(*output, input->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
John Richardson73d4aef2018-05-08 14:34:33 +01001194
1195 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
1196
1197 // Configure kernel window
1198 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
1199 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
1200 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
1201
1202 bool window_changed = update_window_and_padding(win, input_access, output_access);
1203 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
1204
1205 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
1206
1207 return std::make_tuple(err, win);
1208}
Georgios Pinitasd9769582017-08-03 10:19:40 +01001209} // namespace
1210
1211NEReductionOperationKernel::NEReductionOperationKernel()
1212 : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE), _border_size()
1213{
1214}
1215
1216BorderSize NEReductionOperationKernel::border_size() const
1217{
1218 return _border_size;
1219}
1220
1221void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
1222{
1223 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001224
John Richardson73d4aef2018-05-08 14:34:33 +01001225 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001226
1227 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
1228
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001229 _input = input;
1230 _output = output;
1231 _border_size = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
1232 _op = op;
1233 _reduction_axis = axis;
Georgios Pinitasd9769582017-08-03 10:19:40 +01001234
1235 // Configure kernel window
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001236 auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis, op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001237
John Richardson73d4aef2018-05-08 14:34:33 +01001238 ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001239
John Richardson73d4aef2018-05-08 14:34:33 +01001240 INEKernel::configure(std::get<1>(win_config));
1241}
1242
1243Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1244{
1245 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001246 ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis, op)));
John Richardson73d4aef2018-05-08 14:34:33 +01001247
1248 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +01001249}
1250
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001251void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001252{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001253 ARM_COMPUTE_UNUSED(info);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001254 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1255 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1256
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001257 reduce_op(window, _input, _output, _reduction_axis, _op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001258}
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001259} // namespace arm_compute