blob: 476b3c8720ba9f6090786fa3672b0481b0f5ebb8 [file] [log] [blame]
Georgios Pinitasd9769582017-08-03 10:19:40 +01001/*
Michalis Spyrouaea14c62019-01-03 11:10:25 +00002 * Copyright (c) 2017-2019 ARM Limited.
Georgios Pinitasd9769582017-08-03 10:19:40 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEReductionOperationKernel.h"
25
26#include "arm_compute/core/Coordinates.h"
27#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/IAccessWindow.h"
29#include "arm_compute/core/ITensor.h"
30#include "arm_compute/core/NEON/INEKernel.h"
31#include "arm_compute/core/NEON/NEMath.h"
John Richardson73d4aef2018-05-08 14:34:33 +010032#include "arm_compute/core/TensorInfo.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010033#include "arm_compute/core/Validate.h"
Michalis Spyrouaea14c62019-01-03 11:10:25 +000034#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010035
Michalis Spyroubcf8a962018-10-12 10:51:31 +010036#include "arm_compute/core/NEON/wrapper/wrapper.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010037#include <arm_neon.h>
38
Michalis Spyroubcf8a962018-10-12 10:51:31 +010039namespace arm_compute
40{
Georgios Pinitasd9769582017-08-03 10:19:40 +010041namespace
42{
// Per-lane index update for FP32 ARG_IDX_MIN/ARG_IDX_MAX reductions.
// a is the candidate (already min/max-combined) value vector, b the previous
// best values, c the previous best indices. Lanes where the candidate wins
// get the current element index; only c.val[0] is meaningful for 4-lane data.
uint32x4x4_t calculate_index(uint32_t idx, float32x4_t a, float32x4_t b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4_t mask{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        // Lane mask: previous best strictly greater -> candidate is a new min.
        mask = wrapper::vcgt(b, a);
    }
    else
    {
        // Lane mask: previous best strictly less -> candidate is a new max.
        mask = wrapper::vclt(b, a);
    }

    // Along X each lane maps to a consecutive element index...
    uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
    if(axis != 0)
    {
        // ...but along Y/Z/W all lanes belong to the same reduced index.
        vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    // Bit-select: winning lanes take the new index, others keep the old one.
    uint32x4x4_t res = { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 };

    return res;
}
64
// Per-lane index update for U8 ARG_IDX_MIN/ARG_IDX_MAX reductions.
// The 16-lane u8 comparison mask is widened to four u32x4 masks so it can
// select between the 16 candidate element indices and the previous best
// indices held in c.
uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4x4_t mask{ 0 };
    uint8x16_t mask_u8{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask_u8 = wrapper::vcgt(b, a);
    }
    else
    {
        mask_u8 = wrapper::vclt(b, a);
    }
    // Widen the 8-bit mask lanes (0x00/0xFF) towards 32-bit lanes: the
    // shift-left-long by the lane width OR'd with the plain widen fills both
    // halves of every widened lane, preserving all-ones/all-zeros masks.
    auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
    auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
    mask.val[0]     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
    mask.val[1]     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
    mask.val[2]     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
    mask.val[3]     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));

    // Along X the 16 lanes map to 16 consecutive element indices.
    uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
                               { idx + 4, idx + 5, idx + 6, idx + 7 },
                               { idx + 8, idx + 9, idx + 10, idx + 11 },
                               { idx + 12, idx + 13, idx + 14, idx + 15 }
                             }
                           };
    if(axis != 0)
    {
        // Along Y/Z/W every lane belongs to the same reduced index.
        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    // Bit-select per widened quarter: winning lanes take the new index.
    uint32x4x4_t res = { vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
                         vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
                         vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
                         vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
                       };

    return res;
}
105
// Resolve a lane-wise FP32 argmin/argmax state to a single scalar index.
// vec_res_value holds the per-lane best values and vec_res_idx.val[0] the
// per-lane indices; the smallest index among the lanes that hold the overall
// best value is returned.
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float32x4_t vec_res_value, ReductionOperation op)
{
    uint32x4_t res_idx_mask{ 0 };
    uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);

    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        // Horizontal min, broadcast it, then keep indices of matching lanes.
        auto pmin    = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmin         = wrapper::vpmin(pmin, pmin);
        auto mask    = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
        res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
    }
    else
    {
        // Horizontal max, broadcast it, then keep indices of matching lanes.
        // Use wrapper::vceq for consistency with the MIN branch and the other
        // overloads (previously this called vceqq_f32 directly).
        auto pmax    = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmax         = wrapper::vpmax(pmax, pmax);
        auto mask    = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
        res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
    }

    // Non-matching lanes are 0; adding 0xFFFFFFFF maps them to UINT32_MAX and
    // matching lanes to (index - 1), so the horizontal min picks the smallest
    // matching index. The final subtraction undoes the bias (-0xFFFFFFFF == +1
    // in modular arithmetic).
    res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
    auto pmin    = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
    pmin         = wrapper::vpmin(pmin, pmin);
    uint32_t res = wrapper::vgetlane(pmin, 0);

    return (res - 0xFFFFFFFF);
}
133
// Resolve a lane-wise U8 argmin/argmax state to a single scalar index.
// vec_res_value holds 16 per-lane best values, vec_res_idx the 16 matching
// indices (as four u32x4); the smallest index among lanes holding the overall
// best value is returned.
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_value, ReductionOperation op)
{
    uint32x4x4_t res_idx_mask{ 0 };
    uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
    uint8x16_t mask_u8{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        // Four pairwise mins reduce 16 u8 lanes to the scalar min in lane 0;
        // vcombine broadcasts it for the equality mask.
        auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmin      = wrapper::vpmin(pmin, pmin);
        pmin      = wrapper::vpmin(pmin, pmin);
        pmin      = wrapper::vpmin(pmin, pmin);
        mask_u8   = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
    }
    else
    {
        auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmax      = wrapper::vpmax(pmax, pmax);
        pmax      = wrapper::vpmax(pmax, pmax);
        pmax      = wrapper::vpmax(pmax, pmax);
        mask_u8   = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
    }

    // Widen vectors: turn the u8 mask (0x00/0xFF) into four u32 masks
    // (0x0/0xFFFFFFFF) via shift-left-long OR plain widen.
    auto wide_u16_1     = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
    auto wide_u16_2     = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
    auto wide_u32_1     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
    auto wide_u32_2     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
    auto wide_u32_3     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
    auto wide_u32_4     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
    // Keep only the indices of matching lanes, then bias by 0xFFFFFFFF so
    // non-matching lanes become UINT32_MAX and matching lanes (index - 1).
    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
    res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
    res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
    res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
    res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);

    // Horizontal min over all four quarters picks the smallest biased index.
    uint32_t res  = 0xFFFFFFFF;
    int      iter = 0;
    do
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
        pmin      = wrapper::vpmin(pmin, pmin);
        res       = std::min(wrapper::vgetlane(pmin, 0), res);
        iter++;
    }
    while(iter < 4);

    // Undo the bias (-0xFFFFFFFF == +1 in modular arithmetic).
    return (res - 0xFFFFFFFF);
}
185#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// Per-lane index update for FP16 ARG_IDX_MIN/ARG_IDX_MAX reductions.
// The 8-lane u16 comparison mask is widened to two u32x4 masks; only
// res.val[0] and res.val[1] carry meaningful indices for 8-lane data.
uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4x2_t mask{ 0 };
    uint16x8_t mask_u16{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask_u16 = wrapper::vcgt(b, a);
    }
    else
    {
        mask_u16 = wrapper::vclt(b, a);
    }
    // NOTE(review): widening by vmovl alone yields 0x0000FFFF (not
    // 0xFFFFFFFF) for set lanes; sufficient for vbsl selection of indices
    // below 2^16 per lane group — confirm against the u8 overload's full
    // widen if indices can be larger.
    mask.val[0]          = wrapper::vmovl(wrapper::vgetlow(mask_u16));
    mask.val[1]          = wrapper::vmovl(wrapper::vgethigh(mask_u16));
    // Along X the 8 lanes map to 8 consecutive element indices.
    uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
                               { idx + 4, idx + 5, idx + 6, idx + 7 }
                             }
                           };
    if(axis != 0)
    {
        // Along Y/Z/W every lane belongs to the same reduced index.
        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    // Bit-select: winning lanes take the new index, others keep the old one.
    uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
                         wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
                         0, 0
                       };

    return res;
}
216
// Resolve a lane-wise FP16 argmin/argmax state to a single scalar index.
// vec_res_value holds 8 per-lane best values, vec_res_idx.val[0..1] the
// matching indices; the smallest index among lanes holding the overall best
// value is returned.
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
{
    uint32x4x2_t res_idx_mask{ 0 };
    uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
    uint16x8_t mask_u16;
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        // Three pairwise mins reduce 8 lanes to the scalar min in lane 0;
        // vcombine broadcasts it for the equality mask.
        auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmin      = wrapper::vpmin(pmin, pmin);
        pmin      = wrapper::vpmin(pmin, pmin);
        mask_u16  = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
    }
    else
    {
        auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmax      = wrapper::vpmax(pmax, pmax);
        pmax      = wrapper::vpmax(pmax, pmax);
        mask_u16  = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
    }

    // Widen vectors
    // NOTE(review): shift amount here is 8, while the u8 overload widens
    // u16->u32 with 16; a set mask lane becomes 0x00FFFFFF rather than
    // 0xFFFFFFFF. Harmless for the AND below while indices stay < 2^24 —
    // confirm intent upstream.
    auto wide_u32_1     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
    auto wide_u32_2     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
    // Keep only indices of matching lanes, then bias by 0xFFFFFFFF so
    // non-matching lanes become UINT32_MAX and matching lanes (index - 1).
    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);

    // Horizontal min over both halves picks the smallest biased index.
    uint32_t res  = 0xFFFFFFFF;
    int      iter = 0;
    do
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
        pmin      = wrapper::vpmin(pmin, pmin);
        res       = std::min(wrapper::vgetlane(pmin, 0), res);
        iter++;
    }
    while(iter < 2);

    // Undo the bias (-0xFFFFFFFF == +1 in modular arithmetic).
    return (res - 0xFFFFFFFF);
}
258#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
259
Georgios Pinitasd9769582017-08-03 10:19:40 +0100260template <class F>
261class Reducer
262{
263public:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000264 static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100265 {
266 // Set out window
267 Window out_window(window);
268 out_window.set(Window::DimX, Window::Dimension(0, 0, 0));
269
270 // Get first input and output slices
271 Window in_slice = window.first_slice_window_1D();
272 Window out_slice = out_window.first_slice_window_1D();
273
274 do
275 {
276 Iterator in(input, in_slice);
277 Iterator out(output, out_slice);
278
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000279 f(in, out, in_slice, out_slice, *input->info(), op);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100280 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100281 while(window.slide_window_slice_1D(in_slice) && out_window.slide_window_slice_1D(out_slice));
282 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000283 static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100284 {
285 // Set in window
286 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000287 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100288
289 in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000290 out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100291
292 // Get first input and output slices
293 Window in_slice = in_window.first_slice_window_2D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000294 Window out_slice = out_window.first_slice_window_2D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100295
296 do
297 {
298 Iterator in(input, in_slice);
299 Iterator out(output, out_slice);
300
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000301 f(in, out, in_slice, out_slice, *input->info(), 1, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100302 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000303 while(in_window.slide_window_slice_2D(in_slice) && out_window.slide_window_slice_2D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100304 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000305 static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100306 {
307 // Set in window
308 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000309 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100310
311 in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000312 out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100313
314 // Get first input and output slices
315 Window in_slice = in_window.first_slice_window_3D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000316 Window out_slice = out_window.first_slice_window_3D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100317
318 do
319 {
320 Iterator in(input, in_slice);
321 Iterator out(output, out_slice);
322
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000323 f(in, out, in_slice, out_slice, *input->info(), 2, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100324 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000325 while(in_window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_3D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100326 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000327 static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100328 {
329 // Set in/out window
330 Window in_window(window);
331 Window out_window(window);
332
333 in_window.set(3, Window::Dimension(0, 1, 1));
334 out_window.set(3, Window::Dimension(0, 1, 1));
335
336 // Get first input and output slices
337 Window in_slice = in_window.first_slice_window_4D();
338 Window out_slice = out_window.first_slice_window_4D();
339
340 do
341 {
342 Iterator in(input, in_slice);
343 Iterator out(output, out_slice);
344
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000345 f(in, out, in_slice, out_slice, *input->info(), 3, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100346 }
347 while(in_window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_4D(out_slice));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100348 }
349};
350
// Reduction along the X axis for vectorizable types T with S lanes
// (e.g. float/4, float16_t/8). Accumulates lane-wise over the slice, then
// performs the horizontal combine and writes a single scalar per row.
template <typename T, int S>
struct RedOpX
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;

    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
    {
        ARM_COMPUTE_UNUSED(out_slice);
        // Accumulator seed: 0 for sums, 1 for products, the first element for
        // argmin/argmax (so comparisons start from real data).
        auto init_res_value = static_cast<T>(0.f);
        if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
        {
            init_res_value = *reinterpret_cast<T *>(input.ptr());
        }
        else if(op == ReductionOperation::PROD)
        {
            init_res_value = static_cast<T>(1.f);
        }
        auto         vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
        uint32x4x4_t vec_res_idx{ 0 };

        // Lane-wise accumulation over the X extent of this slice.
        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            const auto in_ptr       = reinterpret_cast<const T *>(input.ptr());
            const auto vec_elements = wrapper::vloadq(in_ptr);

            switch(op)
            {
                case ReductionOperation::SUM_SQUARE:
                    vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                    break;
                case ReductionOperation::MEAN_SUM:
                case ReductionOperation::SUM:
                    vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                    break;
                case ReductionOperation::PROD:
                    vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
                    break;
                case ReductionOperation::ARG_IDX_MIN:
                {
                    // Track which lanes improved so their indices are recorded.
                    auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                    vec_res_idx             = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value           = temp_vec_res_value;
                    break;
                }
                case ReductionOperation::ARG_IDX_MAX:
                {
                    auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                    vec_res_idx             = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value           = temp_vec_res_value;
                    break;
                }
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        },
        input);

        // Horizontal combine of the lane accumulators into one scalar.
        switch(op)
        {
            case ReductionOperation::SUM:
            case ReductionOperation::SUM_SQUARE:
            case ReductionOperation::MEAN_SUM:
            {
                // Pairwise adds collapse S lanes to lane 0.
                auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                for(int i = 0; i < S / 4; ++i)
                {
                    carry_res = wrapper::vpadd(carry_res, carry_res);
                }
                auto res = wrapper::vgetlane(carry_res, 0);

                if(op == ReductionOperation::MEAN_SUM)
                {
                    // Mean: divide the sum by the reduced extent.
                    res /= in_info.dimension(0);
                }

                *(reinterpret_cast<T *>(output.ptr())) = res;
                break;
            }
            case ReductionOperation::PROD:
            {
                // Multiply halves, then fold the remaining S/2 lanes serially.
                auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                T    res       = 1;
                for(int i = 0; i < S / 2; ++i)
                {
                    res *= wrapper::vgetlane(carry_res, i);
                }
                *(reinterpret_cast<T *>(output.ptr())) = res;
                break;
            }
            case ReductionOperation::ARG_IDX_MIN:
            case ReductionOperation::ARG_IDX_MAX:
            {
                // Output is the winning element's index (u32), not a value.
                auto res                                      = calculate_vector_index(vec_res_idx, vec_res_value, op);
                *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
                break;
            }
            default:
                ARM_COMPUTE_ERROR("Not supported");
        }
    }
};
453
// Reduction along the X axis for QASYMM8 (u8) data. Sums accumulate in four
// u32x4 vectors to avoid overflow; PROD dequantizes to float, multiplies and
// requantizes; argmin/argmax track lane indices over the raw u8 values.
struct RedOpX_qasymm8
{
    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
    {
        ARM_COMPUTE_UNUSED(out_slice);
        // u32 accumulators for SUM/MEAN_SUM (one per widened quarter).
        auto vec_res_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_res_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_res_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_res_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));

        // f32 accumulators for PROD (seeded with the multiplicative identity).
        auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
        auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
        auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
        auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));

        uint8x16_t vec_res_value = { 0 };

        if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
        {
            // Seed comparisons with the first element so indices start valid.
            vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{});
        }

        uint32x4x4_t vec_res_idx{ 0 };
        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            const auto vec_elements = wrapper::vloadq(input.ptr());
            switch(op)
            {
                case ReductionOperation::SUM:
                case ReductionOperation::MEAN_SUM:
                {
                    // Widen u8 -> u16 -> u32 before accumulating.
                    const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
                    const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));

                    const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
                    const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
                    const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
                    const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));

                    vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
                    vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
                    vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
                    vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
                    break;
                }
                case ReductionOperation::PROD:
                {
                    const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
                    const auto scale32x4f_4  = vdupq_n_f32(in_info.quantization_info().scale);

                    // Widen u8 -> u16 -> u32, then convert to f32.
                    const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
                    const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));

                    const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
                    const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
                    const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
                    const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));

                    auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
                    auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
                    auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
                    auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);

                    //de-quantize vec_elements
                    temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
                    temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
                    temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
                    temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);

                    vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
                    vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
                    vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
                    vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
                    break;
                }
                case ReductionOperation::ARG_IDX_MIN:
                {
                    // Track which lanes improved so their indices are recorded.
                    auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                    vec_res_idx             = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value           = temp_vec_res_value;
                    break;
                }
                case ReductionOperation::ARG_IDX_MAX:
                {
                    auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                    vec_res_idx             = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value           = temp_vec_res_value;
                    break;
                }
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        },
        input);

        if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
        {
            // Output is the winning element's index (u32), not a value.
            auto res                                      = calculate_vector_index(vec_res_idx, vec_res_value, op);
            *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
        }
        else if(op == ReductionOperation::PROD)
        {
            // Fold the four f32 accumulators, then the four remaining lanes.
            auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
            carry_res      = wrapper::vmul(carry_res, vec_res_value3_f);
            carry_res      = wrapper::vmul(carry_res, vec_res_value4_f);

            float res = wrapper::vgetlane(carry_res, 0);
            res *= wrapper::vgetlane(carry_res, 1);
            res *= wrapper::vgetlane(carry_res, 2);
            res *= wrapper::vgetlane(carry_res, 3);

            //re-quantize result
            res           = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset);
            *(output.ptr()) = static_cast<uint8_t>(res);
        }
        else
        {
            // SUM / MEAN_SUM: fold the four u32 accumulators horizontally.
            auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
            carry_res      = wrapper::vadd(carry_res, vec_res_value3);
            carry_res      = wrapper::vadd(carry_res, vec_res_value4);

            auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
            carry_paddition      = wrapper::vpadd(carry_paddition, carry_paddition);
            auto res             = wrapper::vgetlane(carry_paddition, 0);

            if(op == ReductionOperation::MEAN_SUM)
            {
                // Mean: divide the sum by the reduced extent.
                res /= in_info.dimension(0);
            }

            // NOTE(review): result is truncated to u8; saturation/rounding is
            // presumably handled by the operation's validate step — confirm.
            *(output.ptr()) = static_cast<uint8_t>(res);
        }
    }
};
588
// Reduction along the Y, Z or W axis for vectorizable types T with S lanes.
// For each output vector, walks the reduced dimension via explicit byte
// offsets into the input tensor and accumulates lane-wise (no horizontal
// combine is needed: lanes map to distinct X positions).
template <typename T, int S>
struct RedOpYZW
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
    using neon_vector  = typename wrapper::traits::neon_vector<T, S>::type;

    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
    {
        ARM_COMPUTE_UNUSED(out_slice);

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            // Accumulator seed: first element for argmin/argmax, 1 for
            // products, 0 for sums.
            neon_vector vec_res_value = { 0 };
            if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
            {
                vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
            }
            else if(op == ReductionOperation::PROD)
            {
                vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
            }
            else
            {
                vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
            }
            uint32x4x4_t vec_res_idx{ 0 };

            // Walk the reduced dimension explicitly; the window itself only
            // covers the non-reduced dimensions.
            for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
            {
                T *in_ptr;
                switch(axis)
                {
                    case 1:
                        in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim)));
                        break;
                    case 2:
                        in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim)));
                        break;
                    case 3:
                        in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim)));
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
                const auto vec_elements = wrapper::vloadq(in_ptr);

                switch(op)
                {
                    case ReductionOperation::SUM:
                    case ReductionOperation::MEAN_SUM:
                        vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                        break;
                    case ReductionOperation::SUM_SQUARE:
                        vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                        break;
                    case ReductionOperation::PROD:
                        vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
                        break;
                    case ReductionOperation::ARG_IDX_MIN:
                    {
                        // Track which lanes improved so their index is recorded.
                        auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                        vec_res_idx             = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                        vec_res_value           = temp_vec_res_value;
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MAX:
                    {
                        auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                        vec_res_idx             = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                        vec_res_value           = temp_vec_res_value;
                        break;
                    }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            }

            if(op == ReductionOperation::MEAN_SUM)
            {
                // Mean: multiply by the reciprocal of the reduced extent.
                auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
                vec_res_value      = wrapper::vmul(vec_res_value, vec_width_inv);
            }

            if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
            {
                // Output holds u32 indices (one per lane of the first quarter).
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
            }
            else
            {
                wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value);
            }
        },
        input, output);
    }
};
685
// Reduction functor for QASYMM8 input along the Y, Z or W axis.
// For each output position it walks all elements along the reduction axis.
// Each 16-lane u8 load is widened (u8 -> u16 -> u32) into four u32x4
// accumulators so sums cannot saturate; PROD instead dequantizes the lanes
// to f32, multiplies, and re-quantizes at the end.
struct RedOpYZW_qasymm8
{
    // @param[in]     input     Input iterator, positioned at the first element along the reduction axis.
    // @param[in,out] output    Output iterator where the reduced vector (or arg-index vector) is stored.
    // @param[in]     in_slice  Window describing the input slice to iterate over.
    // @param[in]     out_slice Unused.
    // @param[in]     in_info   Tensor info of the input (dimensions, strides, quantization info).
    // @param[in]     axis      Axis being reduced (1 = Y, 2 = Z, 3 = W).
    // @param[in]     op        Reduction operation to perform.
    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
    {
        ARM_COMPUTE_UNUSED(out_slice);

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            // Per-lane indices of the current min/max (only used by ARG_IDX_*)
            uint32x4x4_t vec_res_idx{ 0 };
            // Four u32x4 accumulators covering the 16 u8 lanes (SUM/MEAN_SUM)
            auto vec_res_value1 = vdupq_n_u32(0);
            auto vec_res_value2 = vdupq_n_u32(0);
            auto vec_res_value3 = vdupq_n_u32(0);
            auto vec_res_value4 = vdupq_n_u32(0);

            // f32 accumulators for PROD, initialised to the multiplicative identity
            auto vec_res_value1_f = vdupq_n_f32(1);
            auto vec_res_value2_f = vdupq_n_f32(1);
            auto vec_res_value3_f = vdupq_n_f32(1);
            auto vec_res_value4_f = vdupq_n_f32(1);

            // Running min/max candidate for ARG_IDX_*, seeded with the first row
            auto vec_res_value = wrapper::vloadq(input.ptr());

            for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
            {
                // Address of the 16-element row at position index_dim along the reduced axis
                uint8_t *in_ptr;
                switch(axis)
                {
                    case 1:
                        in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, index_dim));
                        break;
                    case 2:
                        in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, index_dim));
                        break;
                    case 3:
                        in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, index_dim));
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
                const auto vec_elements = wrapper::vloadq(in_ptr);

                switch(op)
                {
                    case ReductionOperation::SUM:
                    case ReductionOperation::MEAN_SUM:
                    {
                        // Widen u8x16 -> two u16x8 -> four u32x4 and accumulate
                        const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
                        const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));

                        const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
                        const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
                        const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
                        const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));

                        vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
                        vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
                        vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
                        vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
                        break;
                    }
                    case ReductionOperation::PROD:
                    {
                        const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
                        const auto scale32x4f_4 = vdupq_n_f32(in_info.quantization_info().scale);

                        // Widen u8x16 -> four u32x4 lanes
                        const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
                        const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));

                        const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
                        const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
                        const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
                        const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));

                        auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
                        auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
                        auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
                        auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);

                        //de-quantize vec_elements: real_value = (quantized - offset) * scale
                        temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
                        temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
                        temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
                        temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);

                        // Accumulate the product in float space
                        vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
                        vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
                        vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
                        vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MIN:
                    {
                        // Track per-lane minima and remember the axis index where each occurred
                        auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                        vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                        vec_res_value = temp_vec_res_value;
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MAX:
                    {
                        // Track per-lane maxima and remember the axis index where each occurred
                        auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                        vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                        vec_res_value = temp_vec_res_value;
                        break;
                    }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            }

            if(op == ReductionOperation::MEAN_SUM)
            {
                // Divide the accumulated sums by the reduced-axis length (in float), then convert back
                const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
                vec_res_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv);
                vec_res_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv);
                vec_res_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv);
                vec_res_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv);

                vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
                vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
                vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
                vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
            }
            else if(op == ReductionOperation::PROD)
            {
                const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
                const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(in_info.quantization_info().scale));

                //re-quantize: quantized = real_value / scale + offset
                vec_res_value1_f = vaddq_f32(vmulq_f32(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
                vec_res_value2_f = vaddq_f32(vmulq_f32(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
                vec_res_value3_f = vaddq_f32(vmulq_f32(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
                vec_res_value4_f = vaddq_f32(vmulq_f32(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);

                vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
                vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
                vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
                vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
            }

            if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
            {
                // Arg-index ops store 16 u32 indices (one per u8 lane)
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vec_res_idx.val[1]);
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vec_res_idx.val[2]);
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vec_res_idx.val[3]);
            }
            else
            {
                // Narrow the four u32x4 accumulators back to one saturated u8x16 and store
                const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
                const auto temp16x8t_2 = vcombine_u16(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
                auto res = vcombine_u8(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
                wrapper::vstore(output.ptr(), res);
            }

        },
        input, output);
    }
};
843
// Dispatch a reduction over the given window to the statically-typed Reducer
// specialisation matching the requested axis and the input data type.
// Axis 0 uses the horizontal (X) functors; axes 1-3 use the YZW functors.
// Each case returns directly, so the missing breaks are intentional; F16
// paths are compiled only when FP16 vector arithmetic is available.
void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
{
    switch(axis)
    {
        case 0:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpX_qasymm8>::reduceX(window, input, output, RedOpX_qasymm8(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 1:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8>::reduceY(window, input, output, RedOpYZW_qasymm8(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 2:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8>::reduceZ(window, input, output, RedOpYZW_qasymm8(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 3:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8>::reduceW(window, input, output, RedOpYZW_qasymm8(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        default:
            ARM_COMPUTE_ERROR("Unsupported reduction axis");
    }
}
John Richardson73d4aef2018-05-08 14:34:33 +0100908
909Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
910{
911 ARM_COMPUTE_UNUSED(op);
912
913 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100914 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
John Richardson73d4aef2018-05-08 14:34:33 +0100915
916 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100917 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
John Richardson73d4aef2018-05-08 14:34:33 +0100918
919 if(output->total_size() != 0)
920 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000921 bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
922 if(!is_arg_min_max)
923 {
924 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Isabella Gottardi0a1090a2019-02-14 18:07:36 +0000925 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000926 }
927 else
928 {
929 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32);
930 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100931 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
John Richardson73d4aef2018-05-08 14:34:33 +0100932
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000933 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +0100934 const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
935 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
936 }
937
938 return Status{};
939}
940
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000941std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op)
John Richardson73d4aef2018-05-08 14:34:33 +0100942{
943 // Calculate output shape and set if empty
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000944 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +0100945
946 // Output auto initialization if not yet initialized
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000947 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
948 DataType output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
Isabella Gottardi0a1090a2019-02-14 18:07:36 +0000949 auto_init_if_empty(*output, output_shape, 1, output_data_type, input->quantization_info());
John Richardson73d4aef2018-05-08 14:34:33 +0100950
951 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
952
953 // Configure kernel window
954 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
955 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
956 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
957
958 bool window_changed = update_window_and_padding(win, input_access, output_access);
959 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
960
961 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
962
963 return std::make_tuple(err, win);
964}
Georgios Pinitasd9769582017-08-03 10:19:40 +0100965} // namespace
966
// Default constructor: members are placeholders until configure() is called.
NEReductionOperationKernel::NEReductionOperationKernel()
    : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE), _border_size()
{
}
971
// Border required by this kernel (right padding for X reductions); set in configure().
BorderSize NEReductionOperationKernel::border_size() const
{
    return _border_size;
}
976
977void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
978{
979 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100980
John Richardson73d4aef2018-05-08 14:34:33 +0100981 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100982
983 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
984
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100985 _input = input;
986 _output = output;
987 _border_size = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
988 _op = op;
989 _reduction_axis = axis;
Georgios Pinitasd9769582017-08-03 10:19:40 +0100990
991 // Configure kernel window
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000992 auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis, op);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100993
John Richardson73d4aef2018-05-08 14:34:33 +0100994 ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100995
John Richardson73d4aef2018-05-08 14:34:33 +0100996 INEKernel::configure(std::get<1>(win_config));
997}
998
999Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1000{
1001 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001002 ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis, op)));
John Richardson73d4aef2018-05-08 14:34:33 +01001003
1004 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +01001005}
1006
// Execute the reduction over the (sub)window assigned to this thread.
void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    // Dispatch to the axis/data-type specific reduction implementation
    reduce_op(window, _input, _output, _reduction_axis, _op);
}
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001015} // namespace arm_compute