blob: 85abda598d5c5e552f11b81c3231a724d6db56a9 [file] [log] [blame]
Georgios Pinitasd9769582017-08-03 10:19:40 +01001/*
Michalis Spyrouaea14c62019-01-03 11:10:25 +00002 * Copyright (c) 2017-2019 ARM Limited.
Georgios Pinitasd9769582017-08-03 10:19:40 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEReductionOperationKernel.h"
25
Georgios Pinitas8f5802f2019-02-22 11:08:32 +000026#include "arm_compute/core/CPP/Validate.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010027#include "arm_compute/core/Coordinates.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/IAccessWindow.h"
30#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/INEKernel.h"
32#include "arm_compute/core/NEON/NEMath.h"
John Richardson73d4aef2018-05-08 14:34:33 +010033#include "arm_compute/core/TensorInfo.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010034#include "arm_compute/core/Validate.h"
Michalis Spyrouaea14c62019-01-03 11:10:25 +000035#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010036
Michalis Spyroubcf8a962018-10-12 10:51:31 +010037#include "arm_compute/core/NEON/wrapper/wrapper.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010038#include <arm_neon.h>
39
Michalis Spyroubcf8a962018-10-12 10:51:31 +010040namespace arm_compute
41{
Georgios Pinitasd9769582017-08-03 10:19:40 +010042namespace
43{
Michalis Spyroub9626ab2019-05-13 17:41:01 +010044template <typename T>
45uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000046{
47 uint32x4_t mask{ 0 };
48 if(op == ReductionOperation::ARG_IDX_MIN)
49 {
50 mask = wrapper::vcgt(b, a);
51 }
52 else
53 {
54 mask = wrapper::vclt(b, a);
55 }
56
57 uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
58 if(axis != 0)
59 {
60 vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
61 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +000062 uint32x4x4_t res = { { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000063
64 return res;
65}
66
Georgios Pinitasfad18382019-06-05 15:12:22 +010067template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +000068uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x4_t c, ReductionOperation op, int axis)
69{
Georgios Pinitasd57891a2019-02-19 18:10:03 +000070 uint32x4x4_t mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000071 uint8x16_t mask_u8{ 0 };
72 if(op == ReductionOperation::ARG_IDX_MIN)
73 {
74 mask_u8 = wrapper::vcgt(b, a);
75 }
76 else
77 {
78 mask_u8 = wrapper::vclt(b, a);
79 }
Michalis Spyrou254a48a2019-01-14 17:27:39 +000080 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
81 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
82 mask.val[0] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
83 mask.val[1] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
84 mask.val[2] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
85 mask.val[3] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
86
Michalis Spyrouaea14c62019-01-03 11:10:25 +000087 uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
88 { idx + 4, idx + 5, idx + 6, idx + 7 },
89 { idx + 8, idx + 9, idx + 10, idx + 11 },
90 { idx + 12, idx + 13, idx + 14, idx + 15 }
91 }
92 };
93 if(axis != 0)
94 {
95 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
96 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
97 vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
98 vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
99 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000100 uint32x4x4_t res =
101 {
102 {
103 vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
104 vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
105 vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
106 vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
107 }
108 };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000109
110 return res;
111}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100112
113// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
114float32x2_t calculate_min(float32x4_t in)
115{
116 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
117 return wrapper::vpmin(pmin, pmin);
118}
119
120// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
121float32x2_t calculate_max(float32x4_t in)
122{
123 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
124 return wrapper::vpmax(pmax, pmax);
125}
126// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
127int32x2_t calculate_min(int32x4_t in)
128{
129 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
130 return wrapper::vpmin(pmin, pmin);
131}
132
133// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
134int32x2_t calculate_max(int32x4_t in)
135{
136 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
137 return wrapper::vpmax(pmax, pmax);
138}
139
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100140template <typename T>
141uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000142{
143 uint32x4_t res_idx_mask{ 0 };
144 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
145
146 if(op == ReductionOperation::ARG_IDX_MIN)
147 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100148 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000149 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
150 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
151 }
152 else
153 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100154 auto pmax = calculate_max(vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100155 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000156 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
157 }
158
159 res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
160 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
161 pmin = wrapper::vpmin(pmin, pmin);
162 uint32_t res = wrapper::vgetlane(pmin, 0);
163
164 return (res - 0xFFFFFFFF);
165}
166
Usama Arifa4a08ad2019-05-20 12:38:33 +0100167// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
168inline uint8x8_t calculate_min(uint8x16_t in)
169{
170 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
171 pmin = wrapper::vpmin(pmin, pmin);
172 pmin = wrapper::vpmin(pmin, pmin);
173 return wrapper::vpmin(pmin, pmin);
174}
175// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
176inline uint8x8_t calculate_max(uint8x16_t in)
177{
178 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
179 pmax = wrapper::vpmax(pmax, pmax);
180 pmax = wrapper::vpmax(pmax, pmax);
181 return wrapper::vpmax(pmax, pmax);
182}
183
Usama Arif0a5a57a2019-05-23 14:20:33 +0100184template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000185uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_value, ReductionOperation op)
186{
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000187 uint32x4x4_t res_idx_mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000188 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
189 uint8x16_t mask_u8{ 0 };
190 if(op == ReductionOperation::ARG_IDX_MIN)
191 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100192 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000193 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
194 }
195 else
196 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100197 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000198 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
199 }
200
201 // Widen vectors
202 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
203 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
204 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
205 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
206 auto wide_u32_3 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
207 auto wide_u32_4 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
208 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
209 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
210 res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
211 res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
212 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
213 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
214 res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
215 res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);
216
217 uint32_t res = 0xFFFFFFFF;
218 int iter = 0;
219 do
220 {
221 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
222 pmin = wrapper::vpmin(pmin, pmin);
223 res = std::min(wrapper::vgetlane(pmin, 0), res);
224 iter++;
225 }
226 while(iter < 4);
227
228 return (res - 0xFFFFFFFF);
229}
230#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitasfad18382019-06-05 15:12:22 +0100231template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000232uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
233{
234 uint32x4x2_t mask{ 0 };
235 uint16x8_t mask_u16{ 0 };
236 if(op == ReductionOperation::ARG_IDX_MIN)
237 {
238 mask_u16 = wrapper::vcgt(b, a);
239 }
240 else
241 {
242 mask_u16 = wrapper::vclt(b, a);
243 }
244 mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
245 mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
246 uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
247 { idx + 4, idx + 5, idx + 6, idx + 7 }
248 }
249 };
250 if(axis != 0)
251 {
252 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
253 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
254 }
255 uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
256 wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
257 0, 0
258 };
259
260 return res;
261}
262
Usama Arifa4a08ad2019-05-20 12:38:33 +0100263// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
264inline float16x4_t calculate_min(float16x8_t in)
265{
266 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
267 pmin = wrapper::vpmin(pmin, pmin);
268 return wrapper::vpmin(pmin, pmin);
269}
270// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
271inline float16x4_t calculate_max(float16x8_t in)
272{
273 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
274 pmax = wrapper::vpmax(pmax, pmax);
275 return wrapper::vpmax(pmax, pmax);
276}
277
Usama Arif0a5a57a2019-05-23 14:20:33 +0100278template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000279uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
280{
281 uint32x4x2_t res_idx_mask{ 0 };
282 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
283 uint16x8_t mask_u16;
284 if(op == ReductionOperation::ARG_IDX_MIN)
285 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100286 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000287 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
288 }
289 else
290 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100291 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000292 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
293 }
294
295 // Widen vectors
296 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
297 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
298 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
299 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
300 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
301 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
302
303 uint32_t res = 0xFFFFFFFF;
304 int iter = 0;
305 do
306 {
307 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
308 pmin = wrapper::vpmin(pmin, pmin);
309 res = std::min(wrapper::vgetlane(pmin, 0), res);
310 iter++;
311 }
312 while(iter < 2);
313
314 return (res - 0xFFFFFFFF);
315}
316#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
317
Georgios Pinitasd9769582017-08-03 10:19:40 +0100318template <class F>
319class Reducer
320{
321public:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000322 static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100323 {
324 // Set out window
325 Window out_window(window);
326 out_window.set(Window::DimX, Window::Dimension(0, 0, 0));
327
328 // Get first input and output slices
329 Window in_slice = window.first_slice_window_1D();
330 Window out_slice = out_window.first_slice_window_1D();
331
332 do
333 {
334 Iterator in(input, in_slice);
335 Iterator out(output, out_slice);
336
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000337 f(in, out, in_slice, out_slice, *input->info(), op);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100338 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100339 while(window.slide_window_slice_1D(in_slice) && out_window.slide_window_slice_1D(out_slice));
340 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000341 static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100342 {
343 // Set in window
344 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000345 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100346
347 in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000348 out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100349
350 // Get first input and output slices
351 Window in_slice = in_window.first_slice_window_2D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000352 Window out_slice = out_window.first_slice_window_2D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100353
354 do
355 {
356 Iterator in(input, in_slice);
357 Iterator out(output, out_slice);
358
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000359 f(in, out, in_slice, out_slice, *input->info(), 1, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100360 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000361 while(in_window.slide_window_slice_2D(in_slice) && out_window.slide_window_slice_2D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100362 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000363 static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100364 {
365 // Set in window
366 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000367 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100368
369 in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000370 out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100371
372 // Get first input and output slices
373 Window in_slice = in_window.first_slice_window_3D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000374 Window out_slice = out_window.first_slice_window_3D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100375
376 do
377 {
378 Iterator in(input, in_slice);
379 Iterator out(output, out_slice);
380
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000381 f(in, out, in_slice, out_slice, *input->info(), 2, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100382 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000383 while(in_window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_3D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100384 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000385 static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100386 {
387 // Set in/out window
388 Window in_window(window);
389 Window out_window(window);
390
391 in_window.set(3, Window::Dimension(0, 1, 1));
392 out_window.set(3, Window::Dimension(0, 1, 1));
393
394 // Get first input and output slices
395 Window in_slice = in_window.first_slice_window_4D();
396 Window out_slice = out_window.first_slice_window_4D();
397
398 do
399 {
400 Iterator in(input, in_slice);
401 Iterator out(output, out_slice);
402
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000403 f(in, out, in_slice, out_slice, *input->info(), 3, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100404 }
405 while(in_window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_4D(out_slice));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100406 }
407};
408
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000409template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100410struct RedOpX
Georgios Pinitasd9769582017-08-03 10:19:40 +0100411{
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100412 /** NEON vector tag type. */
413 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
414
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000415 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100416 {
417 ARM_COMPUTE_UNUSED(out_slice);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000418 auto init_res_value = static_cast<T>(0.f);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100419 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000420 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100421 case ReductionOperation::ARG_IDX_MAX:
422 case ReductionOperation::ARG_IDX_MIN:
423 case ReductionOperation::MIN:
Usama Arif28f0dd92019-05-20 13:44:34 +0100424 case ReductionOperation::MAX:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100425 {
426 init_res_value = *reinterpret_cast<T *>(input.ptr());
427 break;
428 }
429 case ReductionOperation::PROD:
430 {
431 init_res_value = static_cast<T>(1.f);
432 break;
433 }
434 default:
435 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000436 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000437 auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000438 uint32x4x4_t vec_res_idx{ { 0 } };
Georgios Pinitasd9769582017-08-03 10:19:40 +0100439
440 execute_window_loop(in_slice, [&](const Coordinates & id)
441 {
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100442 const auto in_ptr = reinterpret_cast<const T *>(input.ptr());
443 const auto vec_elements = wrapper::vloadq(in_ptr);
444
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000445 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100446 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000447 case ReductionOperation::SUM_SQUARE:
448 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
449 break;
450 case ReductionOperation::MEAN_SUM:
451 case ReductionOperation::SUM:
452 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
453 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000454 case ReductionOperation::PROD:
455 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
456 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000457 case ReductionOperation::ARG_IDX_MIN:
458 {
459 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100460 vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000461 vec_res_value = temp_vec_res_value;
462 break;
463 }
464 case ReductionOperation::ARG_IDX_MAX:
465 {
466 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100467 vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000468 vec_res_value = temp_vec_res_value;
469 break;
470 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100471 case ReductionOperation::MIN:
472 {
473 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
474 break;
475 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100476 case ReductionOperation::MAX:
477 {
478 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
479 break;
480 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000481 default:
482 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100483 }
Georgios Pinitasd9769582017-08-03 10:19:40 +0100484 },
485 input);
486
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000487 switch(op)
Michele Di Giorgio1c948d42018-11-20 16:03:01 +0000488 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000489 case ReductionOperation::SUM:
490 case ReductionOperation::SUM_SQUARE:
491 case ReductionOperation::MEAN_SUM:
492 {
493 auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
494 for(int i = 0; i < S / 4; ++i)
495 {
496 carry_res = wrapper::vpadd(carry_res, carry_res);
497 }
498 auto res = wrapper::vgetlane(carry_res, 0);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100499
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000500 if(op == ReductionOperation::MEAN_SUM)
501 {
502 res /= in_info.dimension(0);
503 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100504
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000505 *(reinterpret_cast<T *>(output.ptr())) = res;
506 break;
507 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000508 case ReductionOperation::PROD:
509 {
510 auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
511 T res = 1;
512 for(int i = 0; i < S / 2; ++i)
513 {
514 res *= wrapper::vgetlane(carry_res, i);
515 }
516 *(reinterpret_cast<T *>(output.ptr())) = res;
517 break;
518 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000519 case ReductionOperation::ARG_IDX_MIN:
520 case ReductionOperation::ARG_IDX_MAX:
521 {
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100522 auto res = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000523 *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
524 break;
525 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100526 case ReductionOperation::MIN:
527 {
528 *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0);
529 break;
530 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100531 case ReductionOperation::MAX:
532 {
533 *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_max(vec_res_value), 0);
534 break;
535 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000536 default:
537 ARM_COMPUTE_ERROR("Not supported");
538 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100539 }
540};
541
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100542struct RedOpX_qasymm8
543{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000544 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100545 {
546 ARM_COMPUTE_UNUSED(out_slice);
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100547
548 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
549
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000550 auto vec_res_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
551 auto vec_res_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
552 auto vec_res_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
553 auto vec_res_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100554
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000555 auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
556 auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
557 auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
558 auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));
559
Michalis Spyrou728d6f72019-01-16 13:57:58 +0000560 uint8x16_t vec_res_value = { 0 };
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000561
Usama Arif28f0dd92019-05-20 13:44:34 +0100562 if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000563 {
564 vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{});
565 }
566
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000567 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100568 execute_window_loop(in_slice, [&](const Coordinates & id)
569 {
570 const auto vec_elements = wrapper::vloadq(input.ptr());
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000571 switch(op)
572 {
573 case ReductionOperation::SUM:
574 case ReductionOperation::MEAN_SUM:
575 {
576 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
577 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100578
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000579 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
580 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
581 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
582 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100583
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000584 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
585 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
586 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
587 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
588 break;
589 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000590 case ReductionOperation::PROD:
591 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100592 const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
593 const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000594
595 const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
596 const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
597
598 const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
599 const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
600 const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
601 const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
602
603 auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
604 auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
605 auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
606 auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
607
608 //de-quantize vec_elements
609 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
610 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
611 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
612 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
613
614 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
615 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
616 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
617 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
618 break;
619 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000620 case ReductionOperation::ARG_IDX_MIN:
621 {
622 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
623 vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
624 vec_res_value = temp_vec_res_value;
625 break;
626 }
627 case ReductionOperation::ARG_IDX_MAX:
628 {
629 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
630 vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
631 vec_res_value = temp_vec_res_value;
632 break;
633 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100634 case ReductionOperation::MIN:
635 {
636 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
637 break;
638 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100639 case ReductionOperation::MAX:
640 {
641 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
642 break;
643 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000644 default:
645 ARM_COMPUTE_ERROR("Not supported");
646 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100647 },
648 input);
649
Usama Arifa4a08ad2019-05-20 12:38:33 +0100650 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100651 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100652 case ReductionOperation::ARG_IDX_MIN:
653 case ReductionOperation::ARG_IDX_MAX:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000654 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100655 auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
656 *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
657 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000658 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100659 case ReductionOperation::MIN:
660 {
661 *(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
662 break;
663 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100664 case ReductionOperation::MAX:
665 {
666 *(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
667 break;
668 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100669 case ReductionOperation::PROD:
670 {
671 auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
672 carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
673 carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000674
Usama Arifa4a08ad2019-05-20 12:38:33 +0100675 float res = wrapper::vgetlane(carry_res, 0);
676 res *= wrapper::vgetlane(carry_res, 1);
677 res *= wrapper::vgetlane(carry_res, 2);
678 res *= wrapper::vgetlane(carry_res, 3);
679
680 //re-quantize result
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100681 res = quantize_qasymm8(res, iq_info);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100682 *(output.ptr()) = static_cast<uint8_t>(res);
683 break;
684 }
685 default:
686 {
687 auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
688 carry_res = wrapper::vadd(carry_res, vec_res_value3);
689 carry_res = wrapper::vadd(carry_res, vec_res_value4);
690
691 auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
692 carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
693 auto res = wrapper::vgetlane(carry_paddition, 0);
694
695 if(op == ReductionOperation::MEAN_SUM)
696 {
697 res /= in_info.dimension(0);
698 }
699
700 *(output.ptr()) = static_cast<uint8_t>(res);
701 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000702 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100703 }
704};
705
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000706template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100707struct RedOpYZW
708{
709 /** NEON vector tag type. */
710 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000711 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100712
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000713 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100714 {
715 ARM_COMPUTE_UNUSED(out_slice);
716
giuros01154bc1c2019-03-26 17:44:40 +0000717 execute_window_loop(in_slice, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100718 {
Michalis Spyrou728d6f72019-01-16 13:57:58 +0000719 neon_vector vec_res_value = { 0 };
Usama Arifa4a08ad2019-05-20 12:38:33 +0100720 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000721 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100722 case ReductionOperation::ARG_IDX_MAX:
723 case ReductionOperation::ARG_IDX_MIN:
724 case ReductionOperation::MIN:
Usama Arif28f0dd92019-05-20 13:44:34 +0100725 case ReductionOperation::MAX:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100726 {
727 vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
728 break;
729 }
730 case ReductionOperation::PROD:
731 {
732 vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
733 break;
734 }
735 default:
736 {
737 vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
738 break;
739 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000740 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000741 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000742
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100743 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
744 {
745 T *in_ptr;
746 switch(axis)
747 {
748 case 1:
749 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim)));
750 break;
751 case 2:
752 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim)));
753 break;
754 case 3:
755 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim)));
756 break;
757 default:
758 ARM_COMPUTE_ERROR("Not supported");
759 }
760 const auto vec_elements = wrapper::vloadq(in_ptr);
761
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000762 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100763 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000764 case ReductionOperation::SUM:
765 case ReductionOperation::MEAN_SUM:
766 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
767 break;
768 case ReductionOperation::SUM_SQUARE:
769 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
770 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000771 case ReductionOperation::PROD:
772 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
773 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000774 case ReductionOperation::ARG_IDX_MIN:
775 {
776 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
777 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
778 vec_res_value = temp_vec_res_value;
779 break;
780 }
781 case ReductionOperation::ARG_IDX_MAX:
782 {
783 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
784 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
785 vec_res_value = temp_vec_res_value;
786 break;
787 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100788 case ReductionOperation::MIN:
789 {
790 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
791 break;
792 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100793 case ReductionOperation::MAX:
794 {
795 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
796 break;
797 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000798 default:
799 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100800 }
801 }
802
803 if(op == ReductionOperation::MEAN_SUM)
804 {
805 auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000806 vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100807 }
808
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000809 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
810 {
811 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
Michele Di Giorgio81d7e782019-08-16 18:03:35 +0100812#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
813 if(std::is_same<T, float16_t>::value)
814 {
815 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vec_res_idx.val[1]);
816 }
817#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000818 }
819 else
820 {
821 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value);
822 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100823 },
824 input, output);
825 }
826};
827
giuros01154bc1c2019-03-26 17:44:40 +0000828template <typename T, int S, int axis, ReductionOperation op>
829struct RedOpYZW_complex
830{
831 /** NEON vector tag type. */
832 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
833 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
834
835 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int, const ReductionOperation)
836 {
837 ARM_COMPUTE_UNUSED(out_slice);
838 ARM_COMPUTE_ERROR_ON(axis != 2);
839
840 const size_t stride_z = in_info.strides_in_bytes()[axis];
841
842 execute_window_loop(in_slice, [&](const Coordinates &)
843 {
844 neon_vector vec_res_value_0 = { 0 };
845 neon_vector vec_res_value_1 = { 0 };
846
847 vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
848 vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
849
850 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
851 {
852 T *in_ptr_0;
853 T *in_ptr_1;
854 switch(axis)
855 {
856 case 2:
857 in_ptr_0 = reinterpret_cast<T *>(input.ptr() + stride_z * dim);
858 in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 16 + stride_z * dim);
859 break;
860 default:
861 ARM_COMPUTE_ERROR("Not supported");
862 }
863 const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
864 const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);
865
866 switch(op)
867 {
868 case ReductionOperation::SUM:
869 vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
870 vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
871 break;
872 default:
873 ARM_COMPUTE_ERROR("Not supported");
874 }
875 }
876
877 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value_0);
878 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + 16), vec_res_value_1);
879
880 },
881 input, output);
882 }
883};
884
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100885struct RedOpYZW_qasymm8
886{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000887 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100888 {
889 ARM_COMPUTE_UNUSED(out_slice);
890
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100891 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
892
giuros01154bc1c2019-03-26 17:44:40 +0000893 execute_window_loop(in_slice, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100894 {
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000895 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000896 auto vec_res_value1 = vdupq_n_u32(0);
897 auto vec_res_value2 = vdupq_n_u32(0);
898 auto vec_res_value3 = vdupq_n_u32(0);
899 auto vec_res_value4 = vdupq_n_u32(0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000900
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000901 auto vec_res_value1_f = vdupq_n_f32(1);
902 auto vec_res_value2_f = vdupq_n_f32(1);
903 auto vec_res_value3_f = vdupq_n_f32(1);
904 auto vec_res_value4_f = vdupq_n_f32(1);
905
906 auto vec_res_value = wrapper::vloadq(input.ptr());
907
908 for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100909 {
910 uint8_t *in_ptr;
911 switch(axis)
912 {
913 case 1:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000914 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100915 break;
916 case 2:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000917 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100918 break;
919 case 3:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000920 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100921 break;
922 default:
923 ARM_COMPUTE_ERROR("Not supported");
924 }
925 const auto vec_elements = wrapper::vloadq(in_ptr);
926
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000927 switch(op)
928 {
929 case ReductionOperation::SUM:
930 case ReductionOperation::MEAN_SUM:
931 {
932 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
933 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100934
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000935 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
936 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
937 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
938 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100939
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000940 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
941 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
942 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
943 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
944 break;
945 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000946 case ReductionOperation::PROD:
947 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100948 const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
949 const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000950
951 const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
952 const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
953
954 const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
955 const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
956 const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
957 const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
958
959 auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
960 auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
961 auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
962 auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
963
964 //de-quantize vec_elements
965 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
966 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
967 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
968 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
969
970 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
971 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
972 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
973 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
974 break;
975 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000976 case ReductionOperation::ARG_IDX_MIN:
977 {
978 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000979 vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000980 vec_res_value = temp_vec_res_value;
981 break;
982 }
983 case ReductionOperation::ARG_IDX_MAX:
984 {
985 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000986 vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000987 vec_res_value = temp_vec_res_value;
988 break;
989 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100990 case ReductionOperation::MIN:
991 {
992 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
993 break;
994 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100995 case ReductionOperation::MAX:
996 {
997 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
998 break;
999 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001000 default:
1001 ARM_COMPUTE_ERROR("Not supported");
1002 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001003 }
1004
1005 if(op == ReductionOperation::MEAN_SUM)
1006 {
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001007 const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
1008 vec_res_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv);
1009 vec_res_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv);
1010 vec_res_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv);
1011 vec_res_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001012
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001013 vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
1014 vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
1015 vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
1016 vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
1017 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001018 else if(op == ReductionOperation::PROD)
1019 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01001020 const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
1021 const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale));
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001022
1023 //re-quantize
1024 vec_res_value1_f = vaddq_f32(vmulq_f32(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
1025 vec_res_value2_f = vaddq_f32(vmulq_f32(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
1026 vec_res_value3_f = vaddq_f32(vmulq_f32(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
1027 vec_res_value4_f = vaddq_f32(vmulq_f32(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
1028
1029 vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
1030 vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
1031 vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
1032 vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
1033 }
1034
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001035 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
1036 {
1037 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
1038 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vec_res_idx.val[1]);
1039 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vec_res_idx.val[2]);
1040 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vec_res_idx.val[3]);
1041 }
Usama Arifa4a08ad2019-05-20 12:38:33 +01001042 else if(op == ReductionOperation::ARG_IDX_MIN)
1043 {
1044 wrapper::vstore(output.ptr(), vec_res_value);
1045 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001046 else
1047 {
1048 const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1049 const auto temp16x8t_2 = vcombine_u16(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1050 auto res = vcombine_u8(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
1051 wrapper::vstore(output.ptr(), res);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001052 }
1053
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001054 },
1055 input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001056 }
1057};
1058
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001059void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001060{
giuros01154bc1c2019-03-26 17:44:40 +00001061 const bool is_complex = (input->info()->num_channels() == 2);
1062
1063 if(is_complex)
1064 {
1065 switch(axis)
1066 {
1067 case 2:
1068 switch(input->info()->data_type())
1069 {
1070 case DataType::F32:
1071 switch(op)
1072 {
1073 case ReductionOperation::SUM:
1074 return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op);
1075 default:
1076 ARM_COMPUTE_ERROR("Not supported");
1077 }
1078 default:
1079 ARM_COMPUTE_ERROR("Not supported");
1080 }
1081 default:
1082 ARM_COMPUTE_ERROR("Not supported");
1083 }
1084 }
1085
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001086 switch(axis)
1087 {
1088 case 0:
1089 switch(input->info()->data_type())
1090 {
1091 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001092 return Reducer<RedOpX_qasymm8>::reduceX(window, input, output, RedOpX_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001093#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1094 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001095 return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001096#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1097 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001098 return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001099 case DataType::S32:
1100 return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001101 default:
1102 ARM_COMPUTE_ERROR("Not supported");
1103 }
1104 case 1:
1105 switch(input->info()->data_type())
1106 {
1107 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001108 return Reducer<RedOpYZW_qasymm8>::reduceY(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001109#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1110 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001111 return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001112#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1113 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001114 return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001115 case DataType::S32:
1116 return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001117 default:
1118 ARM_COMPUTE_ERROR("Not supported");
1119 }
1120 case 2:
1121 switch(input->info()->data_type())
1122 {
1123 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001124 return Reducer<RedOpYZW_qasymm8>::reduceZ(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001125#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1126 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001127 return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001128#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1129 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001130 return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001131 case DataType::S32:
1132 return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001133 default:
1134 ARM_COMPUTE_ERROR("Not supported");
1135 }
1136 case 3:
1137 switch(input->info()->data_type())
1138 {
1139 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001140 return Reducer<RedOpYZW_qasymm8>::reduceW(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001141#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1142 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001143 return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001144#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1145 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001146 return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001147 case DataType::S32:
1148 return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001149 default:
1150 ARM_COMPUTE_ERROR("Not supported");
1151 }
1152 default:
1153 ARM_COMPUTE_ERROR("Unsupported reduction axis");
1154 }
1155}
John Richardson73d4aef2018-05-08 14:34:33 +01001156
1157Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1158{
1159 ARM_COMPUTE_UNUSED(op);
1160
1161 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
Georgios Pinitas8f5802f2019-02-22 11:08:32 +00001162 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
giuros01154bc1c2019-03-26 17:44:40 +00001163
1164 if(input->num_channels() == 1)
1165 {
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001166 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
giuros01154bc1c2019-03-26 17:44:40 +00001167 }
1168 else
1169 {
1170 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
1171 ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM);
1172 ARM_COMPUTE_RETURN_ERROR_ON(axis != 2);
1173 }
John Richardson73d4aef2018-05-08 14:34:33 +01001174
1175 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001176 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
John Richardson73d4aef2018-05-08 14:34:33 +01001177
1178 if(output->total_size() != 0)
1179 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001180 bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
1181 if(!is_arg_min_max)
1182 {
1183 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Isabella Gottardi0a1090a2019-02-14 18:07:36 +00001184 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
giuros01154bc1c2019-03-26 17:44:40 +00001185 ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels());
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001186 }
1187 else
1188 {
Michele Di Giorgio9637b2e2019-09-23 16:49:49 +01001189 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32, DataType::S32);
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001190 }
John Richardson73d4aef2018-05-08 14:34:33 +01001191
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001192 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001193 const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
1194 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
1195 }
1196
1197 return Status{};
1198}
1199
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001200std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op)
John Richardson73d4aef2018-05-08 14:34:33 +01001201{
1202 // Calculate output shape and set if empty
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001203 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001204
1205 // Output auto initialization if not yet initialized
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001206 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
1207 DataType output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
giuros01154bc1c2019-03-26 17:44:40 +00001208 auto_init_if_empty(*output, input->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
John Richardson73d4aef2018-05-08 14:34:33 +01001209
1210 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
1211
1212 // Configure kernel window
1213 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
1214 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
1215 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
1216
1217 bool window_changed = update_window_and_padding(win, input_access, output_access);
1218 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
1219
1220 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
1221
1222 return std::make_tuple(err, win);
1223}
Georgios Pinitasd9769582017-08-03 10:19:40 +01001224} // namespace
1225
1226NEReductionOperationKernel::NEReductionOperationKernel()
1227 : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE), _border_size()
1228{
1229}
1230
1231BorderSize NEReductionOperationKernel::border_size() const
1232{
1233 return _border_size;
1234}
1235
1236void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
1237{
1238 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001239
John Richardson73d4aef2018-05-08 14:34:33 +01001240 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001241
1242 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
1243
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001244 _input = input;
1245 _output = output;
1246 _border_size = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
1247 _op = op;
1248 _reduction_axis = axis;
Georgios Pinitasd9769582017-08-03 10:19:40 +01001249
1250 // Configure kernel window
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001251 auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis, op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001252
John Richardson73d4aef2018-05-08 14:34:33 +01001253 ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001254
John Richardson73d4aef2018-05-08 14:34:33 +01001255 INEKernel::configure(std::get<1>(win_config));
1256}
1257
1258Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1259{
1260 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001261 ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis, op)));
John Richardson73d4aef2018-05-08 14:34:33 +01001262
1263 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +01001264}
1265
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001266void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001267{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001268 ARM_COMPUTE_UNUSED(info);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001269 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1270 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1271
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001272 reduce_op(window, _input, _output, _reduction_axis, _op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001273}
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001274} // namespace arm_compute