blob: 67ccc5d736fc1533dc67aabc2205cc77284f5670 [file] [log] [blame]
Georgios Pinitasd9769582017-08-03 10:19:40 +01001/*
Michalis Spyrouaea14c62019-01-03 11:10:25 +00002 * Copyright (c) 2017-2019 ARM Limited.
Georgios Pinitasd9769582017-08-03 10:19:40 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEReductionOperationKernel.h"
25
Georgios Pinitas8f5802f2019-02-22 11:08:32 +000026#include "arm_compute/core/CPP/Validate.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010027#include "arm_compute/core/Coordinates.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/IAccessWindow.h"
30#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/INEKernel.h"
32#include "arm_compute/core/NEON/NEMath.h"
John Richardson73d4aef2018-05-08 14:34:33 +010033#include "arm_compute/core/TensorInfo.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010034#include "arm_compute/core/Validate.h"
Michalis Spyrouaea14c62019-01-03 11:10:25 +000035#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010036
Michalis Spyroubcf8a962018-10-12 10:51:31 +010037#include "arm_compute/core/NEON/wrapper/wrapper.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010038#include <arm_neon.h>
39
Michalis Spyroubcf8a962018-10-12 10:51:31 +010040namespace arm_compute
41{
Georgios Pinitasd9769582017-08-03 10:19:40 +010042namespace
43{
Michalis Spyroub9626ab2019-05-13 17:41:01 +010044template <typename T>
45uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000046{
47 uint32x4_t mask{ 0 };
48 if(op == ReductionOperation::ARG_IDX_MIN)
49 {
50 mask = wrapper::vcgt(b, a);
51 }
52 else
53 {
54 mask = wrapper::vclt(b, a);
55 }
56
57 uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
58 if(axis != 0)
59 {
60 vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
61 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +000062 uint32x4x4_t res = { { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000063
64 return res;
65}
66
Georgios Pinitasfad18382019-06-05 15:12:22 +010067template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +000068uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x4_t c, ReductionOperation op, int axis)
69{
Georgios Pinitasd57891a2019-02-19 18:10:03 +000070 uint32x4x4_t mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000071 uint8x16_t mask_u8{ 0 };
72 if(op == ReductionOperation::ARG_IDX_MIN)
73 {
74 mask_u8 = wrapper::vcgt(b, a);
75 }
76 else
77 {
78 mask_u8 = wrapper::vclt(b, a);
79 }
Michalis Spyrou254a48a2019-01-14 17:27:39 +000080 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
81 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
82 mask.val[0] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
83 mask.val[1] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
84 mask.val[2] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
85 mask.val[3] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
86
Michalis Spyrouaea14c62019-01-03 11:10:25 +000087 uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
88 { idx + 4, idx + 5, idx + 6, idx + 7 },
89 { idx + 8, idx + 9, idx + 10, idx + 11 },
90 { idx + 12, idx + 13, idx + 14, idx + 15 }
91 }
92 };
93 if(axis != 0)
94 {
95 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
96 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
97 vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
98 vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
99 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000100 uint32x4x4_t res =
101 {
102 {
103 vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
104 vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
105 vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
106 vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
107 }
108 };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000109
110 return res;
111}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100112
113// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
114float32x2_t calculate_min(float32x4_t in)
115{
116 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
117 return wrapper::vpmin(pmin, pmin);
118}
119
120// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
121float32x2_t calculate_max(float32x4_t in)
122{
123 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
124 return wrapper::vpmax(pmax, pmax);
125}
126// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
127int32x2_t calculate_min(int32x4_t in)
128{
129 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
130 return wrapper::vpmin(pmin, pmin);
131}
132
133// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
134int32x2_t calculate_max(int32x4_t in)
135{
136 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
137 return wrapper::vpmax(pmax, pmax);
138}
139
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100140template <typename T>
141uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000142{
143 uint32x4_t res_idx_mask{ 0 };
144 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
145
146 if(op == ReductionOperation::ARG_IDX_MIN)
147 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100148 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000149 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
150 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
151 }
152 else
153 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100154 auto pmax = calculate_max(vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100155 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000156 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
157 }
158
159 res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
160 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
161 pmin = wrapper::vpmin(pmin, pmin);
162 uint32_t res = wrapper::vgetlane(pmin, 0);
163
164 return (res - 0xFFFFFFFF);
165}
166
Usama Arifa4a08ad2019-05-20 12:38:33 +0100167// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
168inline uint8x8_t calculate_min(uint8x16_t in)
169{
170 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
171 pmin = wrapper::vpmin(pmin, pmin);
172 pmin = wrapper::vpmin(pmin, pmin);
173 return wrapper::vpmin(pmin, pmin);
174}
175// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
176inline uint8x8_t calculate_max(uint8x16_t in)
177{
178 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
179 pmax = wrapper::vpmax(pmax, pmax);
180 pmax = wrapper::vpmax(pmax, pmax);
181 return wrapper::vpmax(pmax, pmax);
182}
183
Usama Arif0a5a57a2019-05-23 14:20:33 +0100184template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000185uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_value, ReductionOperation op)
186{
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000187 uint32x4x4_t res_idx_mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000188 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
189 uint8x16_t mask_u8{ 0 };
190 if(op == ReductionOperation::ARG_IDX_MIN)
191 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100192 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000193 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
194 }
195 else
196 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100197 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000198 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
199 }
200
201 // Widen vectors
202 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
203 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
204 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
205 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
206 auto wide_u32_3 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
207 auto wide_u32_4 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
208 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
209 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
210 res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
211 res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
212 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
213 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
214 res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
215 res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);
216
217 uint32_t res = 0xFFFFFFFF;
218 int iter = 0;
219 do
220 {
221 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
222 pmin = wrapper::vpmin(pmin, pmin);
223 res = std::min(wrapper::vgetlane(pmin, 0), res);
224 iter++;
225 }
226 while(iter < 4);
227
228 return (res - 0xFFFFFFFF);
229}
230#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitasfad18382019-06-05 15:12:22 +0100231template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000232uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
233{
234 uint32x4x2_t mask{ 0 };
235 uint16x8_t mask_u16{ 0 };
236 if(op == ReductionOperation::ARG_IDX_MIN)
237 {
238 mask_u16 = wrapper::vcgt(b, a);
239 }
240 else
241 {
242 mask_u16 = wrapper::vclt(b, a);
243 }
244 mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
245 mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
246 uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
247 { idx + 4, idx + 5, idx + 6, idx + 7 }
248 }
249 };
250 if(axis != 0)
251 {
252 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
253 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
254 }
255 uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
256 wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
257 0, 0
258 };
259
260 return res;
261}
262
Usama Arifa4a08ad2019-05-20 12:38:33 +0100263// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
264inline float16x4_t calculate_min(float16x8_t in)
265{
266 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
267 pmin = wrapper::vpmin(pmin, pmin);
268 return wrapper::vpmin(pmin, pmin);
269}
270// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
271inline float16x4_t calculate_max(float16x8_t in)
272{
273 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
274 pmax = wrapper::vpmax(pmax, pmax);
275 return wrapper::vpmax(pmax, pmax);
276}
277
Usama Arif0a5a57a2019-05-23 14:20:33 +0100278template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000279uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
280{
281 uint32x4x2_t res_idx_mask{ 0 };
282 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
283 uint16x8_t mask_u16;
284 if(op == ReductionOperation::ARG_IDX_MIN)
285 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100286 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000287 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
288 }
289 else
290 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100291 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000292 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
293 }
294
295 // Widen vectors
296 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
297 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
298 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
299 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
300 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
301 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
302
303 uint32_t res = 0xFFFFFFFF;
304 int iter = 0;
305 do
306 {
307 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
308 pmin = wrapper::vpmin(pmin, pmin);
309 res = std::min(wrapper::vgetlane(pmin, 0), res);
310 iter++;
311 }
312 while(iter < 2);
313
314 return (res - 0xFFFFFFFF);
315}
316#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
317
Georgios Pinitasd9769582017-08-03 10:19:40 +0100318template <class F>
319class Reducer
320{
321public:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000322 static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100323 {
324 // Set out window
325 Window out_window(window);
326 out_window.set(Window::DimX, Window::Dimension(0, 0, 0));
327
328 // Get first input and output slices
329 Window in_slice = window.first_slice_window_1D();
330 Window out_slice = out_window.first_slice_window_1D();
331
332 do
333 {
334 Iterator in(input, in_slice);
335 Iterator out(output, out_slice);
336
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000337 f(in, out, in_slice, out_slice, *input->info(), op);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100338 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100339 while(window.slide_window_slice_1D(in_slice) && out_window.slide_window_slice_1D(out_slice));
340 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000341 static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100342 {
343 // Set in window
344 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000345 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100346
347 in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000348 out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100349
350 // Get first input and output slices
351 Window in_slice = in_window.first_slice_window_2D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000352 Window out_slice = out_window.first_slice_window_2D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100353
354 do
355 {
356 Iterator in(input, in_slice);
357 Iterator out(output, out_slice);
358
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000359 f(in, out, in_slice, out_slice, *input->info(), 1, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100360 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000361 while(in_window.slide_window_slice_2D(in_slice) && out_window.slide_window_slice_2D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100362 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000363 static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100364 {
365 // Set in window
366 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000367 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100368
369 in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000370 out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100371
372 // Get first input and output slices
373 Window in_slice = in_window.first_slice_window_3D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000374 Window out_slice = out_window.first_slice_window_3D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100375
376 do
377 {
378 Iterator in(input, in_slice);
379 Iterator out(output, out_slice);
380
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000381 f(in, out, in_slice, out_slice, *input->info(), 2, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100382 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000383 while(in_window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_3D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100384 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000385 static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100386 {
387 // Set in/out window
388 Window in_window(window);
389 Window out_window(window);
390
391 in_window.set(3, Window::Dimension(0, 1, 1));
392 out_window.set(3, Window::Dimension(0, 1, 1));
393
394 // Get first input and output slices
395 Window in_slice = in_window.first_slice_window_4D();
396 Window out_slice = out_window.first_slice_window_4D();
397
398 do
399 {
400 Iterator in(input, in_slice);
401 Iterator out(output, out_slice);
402
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000403 f(in, out, in_slice, out_slice, *input->info(), 3, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100404 }
405 while(in_window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_4D(out_slice));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100406 }
407};
408
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000409template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100410struct RedOpX
Georgios Pinitasd9769582017-08-03 10:19:40 +0100411{
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100412 /** NEON vector tag type. */
413 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
414
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000415 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100416 {
417 ARM_COMPUTE_UNUSED(out_slice);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000418 auto init_res_value = static_cast<T>(0.f);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100419 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000420 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100421 case ReductionOperation::ARG_IDX_MAX:
422 case ReductionOperation::ARG_IDX_MIN:
423 case ReductionOperation::MIN:
Usama Arif28f0dd92019-05-20 13:44:34 +0100424 case ReductionOperation::MAX:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100425 {
426 init_res_value = *reinterpret_cast<T *>(input.ptr());
427 break;
428 }
429 case ReductionOperation::PROD:
430 {
431 init_res_value = static_cast<T>(1.f);
432 break;
433 }
434 default:
435 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000436 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000437 auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000438 uint32x4x4_t vec_res_idx{ { 0 } };
Georgios Pinitasd9769582017-08-03 10:19:40 +0100439
440 execute_window_loop(in_slice, [&](const Coordinates & id)
441 {
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100442 const auto in_ptr = reinterpret_cast<const T *>(input.ptr());
443 const auto vec_elements = wrapper::vloadq(in_ptr);
444
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000445 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100446 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000447 case ReductionOperation::SUM_SQUARE:
448 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
449 break;
450 case ReductionOperation::MEAN_SUM:
451 case ReductionOperation::SUM:
452 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
453 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000454 case ReductionOperation::PROD:
455 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
456 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000457 case ReductionOperation::ARG_IDX_MIN:
458 {
459 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100460 vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000461 vec_res_value = temp_vec_res_value;
462 break;
463 }
464 case ReductionOperation::ARG_IDX_MAX:
465 {
466 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100467 vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000468 vec_res_value = temp_vec_res_value;
469 break;
470 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100471 case ReductionOperation::MIN:
472 {
473 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
474 break;
475 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100476 case ReductionOperation::MAX:
477 {
478 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
479 break;
480 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000481 default:
482 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100483 }
Georgios Pinitasd9769582017-08-03 10:19:40 +0100484 },
485 input);
486
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000487 switch(op)
Michele Di Giorgio1c948d42018-11-20 16:03:01 +0000488 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000489 case ReductionOperation::SUM:
490 case ReductionOperation::SUM_SQUARE:
491 case ReductionOperation::MEAN_SUM:
492 {
493 auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
494 for(int i = 0; i < S / 4; ++i)
495 {
496 carry_res = wrapper::vpadd(carry_res, carry_res);
497 }
498 auto res = wrapper::vgetlane(carry_res, 0);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100499
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000500 if(op == ReductionOperation::MEAN_SUM)
501 {
502 res /= in_info.dimension(0);
503 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100504
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000505 *(reinterpret_cast<T *>(output.ptr())) = res;
506 break;
507 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000508 case ReductionOperation::PROD:
509 {
510 auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
511 T res = 1;
512 for(int i = 0; i < S / 2; ++i)
513 {
514 res *= wrapper::vgetlane(carry_res, i);
515 }
516 *(reinterpret_cast<T *>(output.ptr())) = res;
517 break;
518 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000519 case ReductionOperation::ARG_IDX_MIN:
520 case ReductionOperation::ARG_IDX_MAX:
521 {
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100522 auto res = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000523 *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
524 break;
525 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100526 case ReductionOperation::MIN:
527 {
528 *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0);
529 break;
530 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100531 case ReductionOperation::MAX:
532 {
533 *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_max(vec_res_value), 0);
534 break;
535 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000536 default:
537 ARM_COMPUTE_ERROR("Not supported");
538 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100539 }
540};
541
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100542struct RedOpX_qasymm8
543{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000544 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100545 {
546 ARM_COMPUTE_UNUSED(out_slice);
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100547
548 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
549
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000550 auto vec_res_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
551 auto vec_res_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
552 auto vec_res_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
553 auto vec_res_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100554
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000555 auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
556 auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
557 auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
558 auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));
559
Michalis Spyrou728d6f72019-01-16 13:57:58 +0000560 uint8x16_t vec_res_value = { 0 };
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000561
Usama Arif28f0dd92019-05-20 13:44:34 +0100562 if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000563 {
564 vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{});
565 }
566
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000567 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100568 execute_window_loop(in_slice, [&](const Coordinates & id)
569 {
570 const auto vec_elements = wrapper::vloadq(input.ptr());
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000571 switch(op)
572 {
573 case ReductionOperation::SUM:
574 case ReductionOperation::MEAN_SUM:
575 {
576 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
577 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100578
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000579 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
580 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
581 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
582 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100583
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000584 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
585 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
586 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
587 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
588 break;
589 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000590 case ReductionOperation::PROD:
591 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100592 const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
593 const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000594
595 const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
596 const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
597
598 const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
599 const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
600 const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
601 const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
602
603 auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
604 auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
605 auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
606 auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
607
608 //de-quantize vec_elements
609 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
610 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
611 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
612 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
613
614 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
615 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
616 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
617 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
618 break;
619 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000620 case ReductionOperation::ARG_IDX_MIN:
621 {
622 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
623 vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
624 vec_res_value = temp_vec_res_value;
625 break;
626 }
627 case ReductionOperation::ARG_IDX_MAX:
628 {
629 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
630 vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
631 vec_res_value = temp_vec_res_value;
632 break;
633 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100634 case ReductionOperation::MIN:
635 {
636 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
637 break;
638 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100639 case ReductionOperation::MAX:
640 {
641 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
642 break;
643 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000644 default:
645 ARM_COMPUTE_ERROR("Not supported");
646 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100647 },
648 input);
649
Usama Arifa4a08ad2019-05-20 12:38:33 +0100650 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100651 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100652 case ReductionOperation::ARG_IDX_MIN:
653 case ReductionOperation::ARG_IDX_MAX:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000654 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100655 auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
656 *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
657 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000658 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100659 case ReductionOperation::MIN:
660 {
661 *(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
662 break;
663 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100664 case ReductionOperation::MAX:
665 {
666 *(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
667 break;
668 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100669 case ReductionOperation::PROD:
670 {
671 auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
672 carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
673 carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000674
Usama Arifa4a08ad2019-05-20 12:38:33 +0100675 float res = wrapper::vgetlane(carry_res, 0);
676 res *= wrapper::vgetlane(carry_res, 1);
677 res *= wrapper::vgetlane(carry_res, 2);
678 res *= wrapper::vgetlane(carry_res, 3);
679
680 //re-quantize result
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100681 res = quantize_qasymm8(res, iq_info);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100682 *(output.ptr()) = static_cast<uint8_t>(res);
683 break;
684 }
685 default:
686 {
687 auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
688 carry_res = wrapper::vadd(carry_res, vec_res_value3);
689 carry_res = wrapper::vadd(carry_res, vec_res_value4);
690
691 auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
692 carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
693 auto res = wrapper::vgetlane(carry_paddition, 0);
694
695 if(op == ReductionOperation::MEAN_SUM)
696 {
697 res /= in_info.dimension(0);
698 }
699
700 *(output.ptr()) = static_cast<uint8_t>(res);
701 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000702 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100703 }
704};
705
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000706template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100707struct RedOpYZW
708{
709 /** NEON vector tag type. */
710 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000711 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100712
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000713 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100714 {
715 ARM_COMPUTE_UNUSED(out_slice);
716
giuros01154bc1c2019-03-26 17:44:40 +0000717 execute_window_loop(in_slice, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100718 {
Michalis Spyrou728d6f72019-01-16 13:57:58 +0000719 neon_vector vec_res_value = { 0 };
Usama Arifa4a08ad2019-05-20 12:38:33 +0100720 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000721 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100722 case ReductionOperation::ARG_IDX_MAX:
723 case ReductionOperation::ARG_IDX_MIN:
724 case ReductionOperation::MIN:
Usama Arif28f0dd92019-05-20 13:44:34 +0100725 case ReductionOperation::MAX:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100726 {
727 vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
728 break;
729 }
730 case ReductionOperation::PROD:
731 {
732 vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
733 break;
734 }
735 default:
736 {
737 vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
738 break;
739 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000740 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000741 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000742
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100743 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
744 {
745 T *in_ptr;
746 switch(axis)
747 {
748 case 1:
749 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim)));
750 break;
751 case 2:
752 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim)));
753 break;
754 case 3:
755 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim)));
756 break;
757 default:
758 ARM_COMPUTE_ERROR("Not supported");
759 }
760 const auto vec_elements = wrapper::vloadq(in_ptr);
761
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000762 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100763 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000764 case ReductionOperation::SUM:
765 case ReductionOperation::MEAN_SUM:
766 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
767 break;
768 case ReductionOperation::SUM_SQUARE:
769 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
770 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000771 case ReductionOperation::PROD:
772 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
773 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000774 case ReductionOperation::ARG_IDX_MIN:
775 {
776 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
777 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
778 vec_res_value = temp_vec_res_value;
779 break;
780 }
781 case ReductionOperation::ARG_IDX_MAX:
782 {
783 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
784 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
785 vec_res_value = temp_vec_res_value;
786 break;
787 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100788 case ReductionOperation::MIN:
789 {
790 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
791 break;
792 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100793 case ReductionOperation::MAX:
794 {
795 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
796 break;
797 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000798 default:
799 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100800 }
801 }
802
803 if(op == ReductionOperation::MEAN_SUM)
804 {
805 auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000806 vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100807 }
808
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000809 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
810 {
811 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
812 }
813 else
814 {
815 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value);
816 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100817 },
818 input, output);
819 }
820};
821
giuros01154bc1c2019-03-26 17:44:40 +0000822template <typename T, int S, int axis, ReductionOperation op>
823struct RedOpYZW_complex
824{
825 /** NEON vector tag type. */
826 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
827 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
828
829 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int, const ReductionOperation)
830 {
831 ARM_COMPUTE_UNUSED(out_slice);
832 ARM_COMPUTE_ERROR_ON(axis != 2);
833
834 const size_t stride_z = in_info.strides_in_bytes()[axis];
835
836 execute_window_loop(in_slice, [&](const Coordinates &)
837 {
838 neon_vector vec_res_value_0 = { 0 };
839 neon_vector vec_res_value_1 = { 0 };
840
841 vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
842 vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
843
844 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
845 {
846 T *in_ptr_0;
847 T *in_ptr_1;
848 switch(axis)
849 {
850 case 2:
851 in_ptr_0 = reinterpret_cast<T *>(input.ptr() + stride_z * dim);
852 in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 16 + stride_z * dim);
853 break;
854 default:
855 ARM_COMPUTE_ERROR("Not supported");
856 }
857 const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
858 const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);
859
860 switch(op)
861 {
862 case ReductionOperation::SUM:
863 vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
864 vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
865 break;
866 default:
867 ARM_COMPUTE_ERROR("Not supported");
868 }
869 }
870
871 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value_0);
872 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + 16), vec_res_value_1);
873
874 },
875 input, output);
876 }
877};
878
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100879struct RedOpYZW_qasymm8
880{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000881 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100882 {
883 ARM_COMPUTE_UNUSED(out_slice);
884
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100885 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
886
giuros01154bc1c2019-03-26 17:44:40 +0000887 execute_window_loop(in_slice, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100888 {
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000889 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000890 auto vec_res_value1 = vdupq_n_u32(0);
891 auto vec_res_value2 = vdupq_n_u32(0);
892 auto vec_res_value3 = vdupq_n_u32(0);
893 auto vec_res_value4 = vdupq_n_u32(0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000894
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000895 auto vec_res_value1_f = vdupq_n_f32(1);
896 auto vec_res_value2_f = vdupq_n_f32(1);
897 auto vec_res_value3_f = vdupq_n_f32(1);
898 auto vec_res_value4_f = vdupq_n_f32(1);
899
900 auto vec_res_value = wrapper::vloadq(input.ptr());
901
902 for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100903 {
904 uint8_t *in_ptr;
905 switch(axis)
906 {
907 case 1:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000908 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100909 break;
910 case 2:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000911 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100912 break;
913 case 3:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000914 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100915 break;
916 default:
917 ARM_COMPUTE_ERROR("Not supported");
918 }
919 const auto vec_elements = wrapper::vloadq(in_ptr);
920
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000921 switch(op)
922 {
923 case ReductionOperation::SUM:
924 case ReductionOperation::MEAN_SUM:
925 {
926 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
927 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100928
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000929 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
930 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
931 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
932 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100933
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000934 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
935 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
936 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
937 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
938 break;
939 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000940 case ReductionOperation::PROD:
941 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100942 const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
943 const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000944
945 const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
946 const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
947
948 const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
949 const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
950 const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
951 const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
952
953 auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
954 auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
955 auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
956 auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
957
958 //de-quantize vec_elements
959 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
960 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
961 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
962 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
963
964 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
965 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
966 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
967 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
968 break;
969 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000970 case ReductionOperation::ARG_IDX_MIN:
971 {
972 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000973 vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000974 vec_res_value = temp_vec_res_value;
975 break;
976 }
977 case ReductionOperation::ARG_IDX_MAX:
978 {
979 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000980 vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000981 vec_res_value = temp_vec_res_value;
982 break;
983 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100984 case ReductionOperation::MIN:
985 {
986 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
987 break;
988 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100989 case ReductionOperation::MAX:
990 {
991 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
992 break;
993 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000994 default:
995 ARM_COMPUTE_ERROR("Not supported");
996 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100997 }
998
999 if(op == ReductionOperation::MEAN_SUM)
1000 {
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001001 const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
1002 vec_res_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv);
1003 vec_res_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv);
1004 vec_res_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv);
1005 vec_res_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001006
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001007 vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
1008 vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
1009 vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
1010 vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
1011 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001012 else if(op == ReductionOperation::PROD)
1013 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01001014 const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
1015 const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale));
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001016
1017 //re-quantize
1018 vec_res_value1_f = vaddq_f32(vmulq_f32(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
1019 vec_res_value2_f = vaddq_f32(vmulq_f32(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
1020 vec_res_value3_f = vaddq_f32(vmulq_f32(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
1021 vec_res_value4_f = vaddq_f32(vmulq_f32(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
1022
1023 vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
1024 vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
1025 vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
1026 vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
1027 }
1028
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001029 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
1030 {
1031 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
1032 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vec_res_idx.val[1]);
1033 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vec_res_idx.val[2]);
1034 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vec_res_idx.val[3]);
1035 }
Usama Arifa4a08ad2019-05-20 12:38:33 +01001036 else if(op == ReductionOperation::ARG_IDX_MIN)
1037 {
1038 wrapper::vstore(output.ptr(), vec_res_value);
1039 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001040 else
1041 {
1042 const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1043 const auto temp16x8t_2 = vcombine_u16(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1044 auto res = vcombine_u8(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
1045 wrapper::vstore(output.ptr(), res);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001046 }
1047
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001048 },
1049 input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001050 }
1051};
1052
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001053void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001054{
giuros01154bc1c2019-03-26 17:44:40 +00001055 const bool is_complex = (input->info()->num_channels() == 2);
1056
1057 if(is_complex)
1058 {
1059 switch(axis)
1060 {
1061 case 2:
1062 switch(input->info()->data_type())
1063 {
1064 case DataType::F32:
1065 switch(op)
1066 {
1067 case ReductionOperation::SUM:
1068 return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op);
1069 default:
1070 ARM_COMPUTE_ERROR("Not supported");
1071 }
1072 default:
1073 ARM_COMPUTE_ERROR("Not supported");
1074 }
1075 default:
1076 ARM_COMPUTE_ERROR("Not supported");
1077 }
1078 }
1079
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001080 switch(axis)
1081 {
1082 case 0:
1083 switch(input->info()->data_type())
1084 {
1085 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001086 return Reducer<RedOpX_qasymm8>::reduceX(window, input, output, RedOpX_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001087#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1088 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001089 return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001090#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1091 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001092 return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001093 case DataType::S32:
1094 return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001095 default:
1096 ARM_COMPUTE_ERROR("Not supported");
1097 }
1098 case 1:
1099 switch(input->info()->data_type())
1100 {
1101 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001102 return Reducer<RedOpYZW_qasymm8>::reduceY(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001103#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1104 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001105 return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001106#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1107 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001108 return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001109 case DataType::S32:
1110 return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001111 default:
1112 ARM_COMPUTE_ERROR("Not supported");
1113 }
1114 case 2:
1115 switch(input->info()->data_type())
1116 {
1117 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001118 return Reducer<RedOpYZW_qasymm8>::reduceZ(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001119#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1120 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001121 return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001122#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1123 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001124 return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001125 case DataType::S32:
1126 return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001127 default:
1128 ARM_COMPUTE_ERROR("Not supported");
1129 }
1130 case 3:
1131 switch(input->info()->data_type())
1132 {
1133 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001134 return Reducer<RedOpYZW_qasymm8>::reduceW(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001135#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1136 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001137 return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001138#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1139 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001140 return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001141 case DataType::S32:
1142 return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001143 default:
1144 ARM_COMPUTE_ERROR("Not supported");
1145 }
1146 default:
1147 ARM_COMPUTE_ERROR("Unsupported reduction axis");
1148 }
1149}
John Richardson73d4aef2018-05-08 14:34:33 +01001150
1151Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1152{
1153 ARM_COMPUTE_UNUSED(op);
1154
1155 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
Georgios Pinitas8f5802f2019-02-22 11:08:32 +00001156 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
giuros01154bc1c2019-03-26 17:44:40 +00001157
1158 if(input->num_channels() == 1)
1159 {
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001160 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
giuros01154bc1c2019-03-26 17:44:40 +00001161 }
1162 else
1163 {
1164 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
1165 ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM);
1166 ARM_COMPUTE_RETURN_ERROR_ON(axis != 2);
1167 }
John Richardson73d4aef2018-05-08 14:34:33 +01001168
1169 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001170 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
John Richardson73d4aef2018-05-08 14:34:33 +01001171
1172 if(output->total_size() != 0)
1173 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001174 bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
1175 if(!is_arg_min_max)
1176 {
1177 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Isabella Gottardi0a1090a2019-02-14 18:07:36 +00001178 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
giuros01154bc1c2019-03-26 17:44:40 +00001179 ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels());
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001180 }
1181 else
1182 {
1183 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32);
1184 }
John Richardson73d4aef2018-05-08 14:34:33 +01001185
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001186 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001187 const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
1188 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
1189 }
1190
1191 return Status{};
1192}
1193
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001194std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op)
John Richardson73d4aef2018-05-08 14:34:33 +01001195{
1196 // Calculate output shape and set if empty
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001197 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001198
1199 // Output auto initialization if not yet initialized
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001200 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
1201 DataType output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
giuros01154bc1c2019-03-26 17:44:40 +00001202 auto_init_if_empty(*output, input->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
John Richardson73d4aef2018-05-08 14:34:33 +01001203
1204 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
1205
1206 // Configure kernel window
1207 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
1208 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
1209 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
1210
1211 bool window_changed = update_window_and_padding(win, input_access, output_access);
1212 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
1213
1214 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
1215
1216 return std::make_tuple(err, win);
1217}
Georgios Pinitasd9769582017-08-03 10:19:40 +01001218} // namespace
1219
1220NEReductionOperationKernel::NEReductionOperationKernel()
1221 : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE), _border_size()
1222{
1223}
1224
1225BorderSize NEReductionOperationKernel::border_size() const
1226{
1227 return _border_size;
1228}
1229
1230void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
1231{
1232 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001233
John Richardson73d4aef2018-05-08 14:34:33 +01001234 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001235
1236 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
1237
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001238 _input = input;
1239 _output = output;
1240 _border_size = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
1241 _op = op;
1242 _reduction_axis = axis;
Georgios Pinitasd9769582017-08-03 10:19:40 +01001243
1244 // Configure kernel window
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001245 auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis, op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001246
John Richardson73d4aef2018-05-08 14:34:33 +01001247 ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001248
John Richardson73d4aef2018-05-08 14:34:33 +01001249 INEKernel::configure(std::get<1>(win_config));
1250}
1251
1252Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1253{
1254 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001255 ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis, op)));
John Richardson73d4aef2018-05-08 14:34:33 +01001256
1257 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +01001258}
1259
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001260void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001261{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001262 ARM_COMPUTE_UNUSED(info);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001263 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1264 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1265
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001266 reduce_op(window, _input, _output, _reduction_axis, _op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001267}
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001268} // namespace arm_compute