blob: 84cb2236abf41a77bdcf567b23b20edbbc634f6a [file] [log] [blame]
/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/core/NEON/kernels/NEReductionOperationKernel.h"
25
Georgios Pinitas8f5802f2019-02-22 11:08:32 +000026#include "arm_compute/core/CPP/Validate.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010027#include "arm_compute/core/Coordinates.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/IAccessWindow.h"
30#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/INEKernel.h"
32#include "arm_compute/core/NEON/NEMath.h"
John Richardson73d4aef2018-05-08 14:34:33 +010033#include "arm_compute/core/TensorInfo.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010034#include "arm_compute/core/Validate.h"
Michalis Spyrouaea14c62019-01-03 11:10:25 +000035#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010036
Michalis Spyroubcf8a962018-10-12 10:51:31 +010037#include "arm_compute/core/NEON/wrapper/wrapper.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010038#include <arm_neon.h>
39
Michalis Spyroubcf8a962018-10-12 10:51:31 +010040namespace arm_compute
41{
Georgios Pinitasd9769582017-08-03 10:19:40 +010042namespace
43{
Michalis Spyrouaea14c62019-01-03 11:10:25 +000044uint32x4x4_t calculate_index(uint32_t idx, float32x4_t a, float32x4_t b, uint32x4x4_t c, ReductionOperation op, int axis)
45{
46 uint32x4_t mask{ 0 };
47 if(op == ReductionOperation::ARG_IDX_MIN)
48 {
49 mask = wrapper::vcgt(b, a);
50 }
51 else
52 {
53 mask = wrapper::vclt(b, a);
54 }
55
56 uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
57 if(axis != 0)
58 {
59 vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
60 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +000061 uint32x4x4_t res = { { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000062
63 return res;
64}
65
66uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x4_t c, ReductionOperation op, int axis)
67{
Georgios Pinitasd57891a2019-02-19 18:10:03 +000068 uint32x4x4_t mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000069 uint8x16_t mask_u8{ 0 };
70 if(op == ReductionOperation::ARG_IDX_MIN)
71 {
72 mask_u8 = wrapper::vcgt(b, a);
73 }
74 else
75 {
76 mask_u8 = wrapper::vclt(b, a);
77 }
Michalis Spyrou254a48a2019-01-14 17:27:39 +000078 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
79 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
80 mask.val[0] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
81 mask.val[1] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
82 mask.val[2] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
83 mask.val[3] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
84
Michalis Spyrouaea14c62019-01-03 11:10:25 +000085 uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
86 { idx + 4, idx + 5, idx + 6, idx + 7 },
87 { idx + 8, idx + 9, idx + 10, idx + 11 },
88 { idx + 12, idx + 13, idx + 14, idx + 15 }
89 }
90 };
91 if(axis != 0)
92 {
93 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
94 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
95 vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
96 vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
97 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +000098 uint32x4x4_t res =
99 {
100 {
101 vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
102 vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
103 vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
104 vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
105 }
106 };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000107
108 return res;
109}
110
111uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float32x4_t vec_res_value, ReductionOperation op)
112{
113 uint32x4_t res_idx_mask{ 0 };
114 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
115
116 if(op == ReductionOperation::ARG_IDX_MIN)
117 {
118 auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
119 pmin = wrapper::vpmin(pmin, pmin);
120 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
121 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
122 }
123 else
124 {
125 auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
126 pmax = wrapper::vpmax(pmax, pmax);
127 auto mask = vceqq_f32(vec_res_value, wrapper::vcombine(pmax, pmax));
128 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
129 }
130
131 res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
132 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
133 pmin = wrapper::vpmin(pmin, pmin);
134 uint32_t res = wrapper::vgetlane(pmin, 0);
135
136 return (res - 0xFFFFFFFF);
137}
138
139uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_value, ReductionOperation op)
140{
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000141 uint32x4x4_t res_idx_mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000142 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
143 uint8x16_t mask_u8{ 0 };
144 if(op == ReductionOperation::ARG_IDX_MIN)
145 {
146 auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
147 pmin = wrapper::vpmin(pmin, pmin);
148 pmin = wrapper::vpmin(pmin, pmin);
149 pmin = wrapper::vpmin(pmin, pmin);
150 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
151 }
152 else
153 {
154 auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
155 pmax = wrapper::vpmax(pmax, pmax);
156 pmax = wrapper::vpmax(pmax, pmax);
157 pmax = wrapper::vpmax(pmax, pmax);
158 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
159 }
160
161 // Widen vectors
162 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
163 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
164 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
165 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
166 auto wide_u32_3 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
167 auto wide_u32_4 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
168 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
169 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
170 res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
171 res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
172 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
173 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
174 res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
175 res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);
176
177 uint32_t res = 0xFFFFFFFF;
178 int iter = 0;
179 do
180 {
181 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
182 pmin = wrapper::vpmin(pmin, pmin);
183 res = std::min(wrapper::vgetlane(pmin, 0), res);
184 iter++;
185 }
186 while(iter < 4);
187
188 return (res - 0xFFFFFFFF);
189}
190#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
191uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
192{
193 uint32x4x2_t mask{ 0 };
194 uint16x8_t mask_u16{ 0 };
195 if(op == ReductionOperation::ARG_IDX_MIN)
196 {
197 mask_u16 = wrapper::vcgt(b, a);
198 }
199 else
200 {
201 mask_u16 = wrapper::vclt(b, a);
202 }
203 mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
204 mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
205 uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
206 { idx + 4, idx + 5, idx + 6, idx + 7 }
207 }
208 };
209 if(axis != 0)
210 {
211 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
212 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
213 }
214 uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
215 wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
216 0, 0
217 };
218
219 return res;
220}
221
222uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
223{
224 uint32x4x2_t res_idx_mask{ 0 };
225 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
226 uint16x8_t mask_u16;
227 if(op == ReductionOperation::ARG_IDX_MIN)
228 {
229 auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
230 pmin = wrapper::vpmin(pmin, pmin);
231 pmin = wrapper::vpmin(pmin, pmin);
232 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
233 }
234 else
235 {
236 auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
237 pmax = wrapper::vpmax(pmax, pmax);
238 pmax = wrapper::vpmax(pmax, pmax);
239 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
240 }
241
242 // Widen vectors
243 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
244 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
245 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
246 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
247 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
248 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
249
250 uint32_t res = 0xFFFFFFFF;
251 int iter = 0;
252 do
253 {
254 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
255 pmin = wrapper::vpmin(pmin, pmin);
256 res = std::min(wrapper::vgetlane(pmin, 0), res);
257 iter++;
258 }
259 while(iter < 2);
260
261 return (res - 0xFFFFFFFF);
262}
263#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
264
Georgios Pinitasd9769582017-08-03 10:19:40 +0100265template <class F>
266class Reducer
267{
268public:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000269 static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100270 {
271 // Set out window
272 Window out_window(window);
273 out_window.set(Window::DimX, Window::Dimension(0, 0, 0));
274
275 // Get first input and output slices
276 Window in_slice = window.first_slice_window_1D();
277 Window out_slice = out_window.first_slice_window_1D();
278
279 do
280 {
281 Iterator in(input, in_slice);
282 Iterator out(output, out_slice);
283
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000284 f(in, out, in_slice, out_slice, *input->info(), op);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100285 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100286 while(window.slide_window_slice_1D(in_slice) && out_window.slide_window_slice_1D(out_slice));
287 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000288 static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100289 {
290 // Set in window
291 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000292 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100293
294 in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000295 out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100296
297 // Get first input and output slices
298 Window in_slice = in_window.first_slice_window_2D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000299 Window out_slice = out_window.first_slice_window_2D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100300
301 do
302 {
303 Iterator in(input, in_slice);
304 Iterator out(output, out_slice);
305
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000306 f(in, out, in_slice, out_slice, *input->info(), 1, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100307 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000308 while(in_window.slide_window_slice_2D(in_slice) && out_window.slide_window_slice_2D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100309 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000310 static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100311 {
312 // Set in window
313 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000314 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100315
316 in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000317 out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100318
319 // Get first input and output slices
320 Window in_slice = in_window.first_slice_window_3D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000321 Window out_slice = out_window.first_slice_window_3D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100322
323 do
324 {
325 Iterator in(input, in_slice);
326 Iterator out(output, out_slice);
327
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000328 f(in, out, in_slice, out_slice, *input->info(), 2, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100329 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000330 while(in_window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_3D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100331 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000332 static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100333 {
334 // Set in/out window
335 Window in_window(window);
336 Window out_window(window);
337
338 in_window.set(3, Window::Dimension(0, 1, 1));
339 out_window.set(3, Window::Dimension(0, 1, 1));
340
341 // Get first input and output slices
342 Window in_slice = in_window.first_slice_window_4D();
343 Window out_slice = out_window.first_slice_window_4D();
344
345 do
346 {
347 Iterator in(input, in_slice);
348 Iterator out(output, out_slice);
349
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000350 f(in, out, in_slice, out_slice, *input->info(), 3, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100351 }
352 while(in_window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_4D(out_slice));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100353 }
354};
355
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000356template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100357struct RedOpX
Georgios Pinitasd9769582017-08-03 10:19:40 +0100358{
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100359 /** NEON vector tag type. */
360 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
361
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000362 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100363 {
364 ARM_COMPUTE_UNUSED(out_slice);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000365 auto init_res_value = static_cast<T>(0.f);
366 if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
367 {
368 init_res_value = *reinterpret_cast<T *>(input.ptr());
369 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000370 else if(op == ReductionOperation::PROD)
371 {
372 init_res_value = static_cast<T>(1.f);
373 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000374 auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000375 uint32x4x4_t vec_res_idx{ { 0 } };
Georgios Pinitasd9769582017-08-03 10:19:40 +0100376
377 execute_window_loop(in_slice, [&](const Coordinates & id)
378 {
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100379 const auto in_ptr = reinterpret_cast<const T *>(input.ptr());
380 const auto vec_elements = wrapper::vloadq(in_ptr);
381
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000382 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100383 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000384 case ReductionOperation::SUM_SQUARE:
385 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
386 break;
387 case ReductionOperation::MEAN_SUM:
388 case ReductionOperation::SUM:
389 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
390 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000391 case ReductionOperation::PROD:
392 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
393 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000394 case ReductionOperation::ARG_IDX_MIN:
395 {
396 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
397 vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
398 vec_res_value = temp_vec_res_value;
399 break;
400 }
401 case ReductionOperation::ARG_IDX_MAX:
402 {
403 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
404 vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
405 vec_res_value = temp_vec_res_value;
406 break;
407 }
408 default:
409 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100410 }
Georgios Pinitasd9769582017-08-03 10:19:40 +0100411 },
412 input);
413
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000414 switch(op)
Michele Di Giorgio1c948d42018-11-20 16:03:01 +0000415 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000416 case ReductionOperation::SUM:
417 case ReductionOperation::SUM_SQUARE:
418 case ReductionOperation::MEAN_SUM:
419 {
420 auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
421 for(int i = 0; i < S / 4; ++i)
422 {
423 carry_res = wrapper::vpadd(carry_res, carry_res);
424 }
425 auto res = wrapper::vgetlane(carry_res, 0);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100426
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000427 if(op == ReductionOperation::MEAN_SUM)
428 {
429 res /= in_info.dimension(0);
430 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100431
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000432 *(reinterpret_cast<T *>(output.ptr())) = res;
433 break;
434 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000435 case ReductionOperation::PROD:
436 {
437 auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
438 T res = 1;
439 for(int i = 0; i < S / 2; ++i)
440 {
441 res *= wrapper::vgetlane(carry_res, i);
442 }
443 *(reinterpret_cast<T *>(output.ptr())) = res;
444 break;
445 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000446 case ReductionOperation::ARG_IDX_MIN:
447 case ReductionOperation::ARG_IDX_MAX:
448 {
449 auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
450 *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
451 break;
452 }
453 default:
454 ARM_COMPUTE_ERROR("Not supported");
455 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100456 }
457};
458
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100459struct RedOpX_qasymm8
460{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000461 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100462 {
463 ARM_COMPUTE_UNUSED(out_slice);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000464 auto vec_res_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
465 auto vec_res_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
466 auto vec_res_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
467 auto vec_res_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100468
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000469 auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
470 auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
471 auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
472 auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));
473
Michalis Spyrou728d6f72019-01-16 13:57:58 +0000474 uint8x16_t vec_res_value = { 0 };
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000475
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000476 if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
477 {
478 vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{});
479 }
480
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000481 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100482 execute_window_loop(in_slice, [&](const Coordinates & id)
483 {
484 const auto vec_elements = wrapper::vloadq(input.ptr());
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000485 switch(op)
486 {
487 case ReductionOperation::SUM:
488 case ReductionOperation::MEAN_SUM:
489 {
490 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
491 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100492
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000493 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
494 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
495 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
496 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100497
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000498 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
499 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
500 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
501 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
502 break;
503 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000504 case ReductionOperation::PROD:
505 {
506 const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
507 const auto scale32x4f_4 = vdupq_n_f32(in_info.quantization_info().scale);
508
509 const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
510 const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
511
512 const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
513 const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
514 const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
515 const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
516
517 auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
518 auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
519 auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
520 auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
521
522 //de-quantize vec_elements
523 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
524 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
525 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
526 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
527
528 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
529 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
530 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
531 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
532 break;
533 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000534 case ReductionOperation::ARG_IDX_MIN:
535 {
536 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
537 vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
538 vec_res_value = temp_vec_res_value;
539 break;
540 }
541 case ReductionOperation::ARG_IDX_MAX:
542 {
543 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
544 vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
545 vec_res_value = temp_vec_res_value;
546 break;
547 }
548 default:
549 ARM_COMPUTE_ERROR("Not supported");
550 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100551 },
552 input);
553
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000554 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100555 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000556 auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
557 *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100558 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000559 else if(op == ReductionOperation::PROD)
560 {
561 auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
562 carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
563 carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
564
565 float res = wrapper::vgetlane(carry_res, 0);
566 res *= wrapper::vgetlane(carry_res, 1);
567 res *= wrapper::vgetlane(carry_res, 2);
568 res *= wrapper::vgetlane(carry_res, 3);
569
570 //re-quantize result
571 res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset);
572 *(output.ptr()) = static_cast<uint8_t>(res);
573 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000574 else
575 {
576 auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
577 carry_res = wrapper::vadd(carry_res, vec_res_value3);
578 carry_res = wrapper::vadd(carry_res, vec_res_value4);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100579
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000580 auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
581 carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
582 auto res = wrapper::vgetlane(carry_paddition, 0);
583
584 if(op == ReductionOperation::MEAN_SUM)
585 {
586 res /= in_info.dimension(0);
587 }
588
589 *(output.ptr()) = static_cast<uint8_t>(res);
590 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100591 }
592};
593
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000594template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100595struct RedOpYZW
596{
597 /** NEON vector tag type. */
598 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000599 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100600
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000601 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100602 {
603 ARM_COMPUTE_UNUSED(out_slice);
604
605 execute_window_loop(in_slice, [&](const Coordinates & id)
606 {
Michalis Spyrou728d6f72019-01-16 13:57:58 +0000607 neon_vector vec_res_value = { 0 };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000608 if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
609 {
610 vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
611 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000612 else if(op == ReductionOperation::PROD)
613 {
614 vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
615 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000616 else
617 {
618 vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
619 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000620 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000621
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100622 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
623 {
624 T *in_ptr;
625 switch(axis)
626 {
627 case 1:
628 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim)));
629 break;
630 case 2:
631 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim)));
632 break;
633 case 3:
634 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim)));
635 break;
636 default:
637 ARM_COMPUTE_ERROR("Not supported");
638 }
639 const auto vec_elements = wrapper::vloadq(in_ptr);
640
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000641 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100642 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000643 case ReductionOperation::SUM:
644 case ReductionOperation::MEAN_SUM:
645 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
646 break;
647 case ReductionOperation::SUM_SQUARE:
648 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
649 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000650 case ReductionOperation::PROD:
651 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
652 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000653 case ReductionOperation::ARG_IDX_MIN:
654 {
655 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
656 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
657 vec_res_value = temp_vec_res_value;
658 break;
659 }
660 case ReductionOperation::ARG_IDX_MAX:
661 {
662 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
663 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
664 vec_res_value = temp_vec_res_value;
665 break;
666 }
667 default:
668 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100669 }
670 }
671
672 if(op == ReductionOperation::MEAN_SUM)
673 {
674 auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000675 vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100676 }
677
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000678 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
679 {
680 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
681 }
682 else
683 {
684 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value);
685 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100686 },
687 input, output);
688 }
689};
690
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100691struct RedOpYZW_qasymm8
692{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000693 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100694 {
695 ARM_COMPUTE_UNUSED(out_slice);
696
697 execute_window_loop(in_slice, [&](const Coordinates & id)
698 {
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000699 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000700 auto vec_res_value1 = vdupq_n_u32(0);
701 auto vec_res_value2 = vdupq_n_u32(0);
702 auto vec_res_value3 = vdupq_n_u32(0);
703 auto vec_res_value4 = vdupq_n_u32(0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000704
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000705 auto vec_res_value1_f = vdupq_n_f32(1);
706 auto vec_res_value2_f = vdupq_n_f32(1);
707 auto vec_res_value3_f = vdupq_n_f32(1);
708 auto vec_res_value4_f = vdupq_n_f32(1);
709
710 auto vec_res_value = wrapper::vloadq(input.ptr());
711
712 for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100713 {
714 uint8_t *in_ptr;
715 switch(axis)
716 {
717 case 1:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000718 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100719 break;
720 case 2:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000721 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100722 break;
723 case 3:
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000724 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, index_dim));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100725 break;
726 default:
727 ARM_COMPUTE_ERROR("Not supported");
728 }
729 const auto vec_elements = wrapper::vloadq(in_ptr);
730
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000731 switch(op)
732 {
733 case ReductionOperation::SUM:
734 case ReductionOperation::MEAN_SUM:
735 {
736 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
737 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100738
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000739 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
740 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
741 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
742 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100743
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000744 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
745 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
746 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
747 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
748 break;
749 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000750 case ReductionOperation::PROD:
751 {
752 const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
753 const auto scale32x4f_4 = vdupq_n_f32(in_info.quantization_info().scale);
754
755 const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
756 const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
757
758 const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
759 const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
760 const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
761 const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
762
763 auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
764 auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
765 auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
766 auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
767
768 //de-quantize vec_elements
769 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
770 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
771 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
772 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
773
774 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
775 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
776 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
777 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
778 break;
779 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000780 case ReductionOperation::ARG_IDX_MIN:
781 {
782 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000783 vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000784 vec_res_value = temp_vec_res_value;
785 break;
786 }
787 case ReductionOperation::ARG_IDX_MAX:
788 {
789 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000790 vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000791 vec_res_value = temp_vec_res_value;
792 break;
793 }
794 default:
795 ARM_COMPUTE_ERROR("Not supported");
796 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100797 }
798
799 if(op == ReductionOperation::MEAN_SUM)
800 {
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000801 const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
802 vec_res_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv);
803 vec_res_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv);
804 vec_res_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv);
805 vec_res_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100806
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000807 vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
808 vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
809 vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
810 vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
811 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000812 else if(op == ReductionOperation::PROD)
813 {
814 const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
815 const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(in_info.quantization_info().scale));
816
817 //re-quantize
818 vec_res_value1_f = vaddq_f32(vmulq_f32(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
819 vec_res_value2_f = vaddq_f32(vmulq_f32(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
820 vec_res_value3_f = vaddq_f32(vmulq_f32(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
821 vec_res_value4_f = vaddq_f32(vmulq_f32(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
822
823 vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
824 vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
825 vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
826 vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
827 }
828
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000829 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
830 {
831 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
832 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vec_res_idx.val[1]);
833 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vec_res_idx.val[2]);
834 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vec_res_idx.val[3]);
835 }
836 else
837 {
838 const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
839 const auto temp16x8t_2 = vcombine_u16(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
840 auto res = vcombine_u8(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
841 wrapper::vstore(output.ptr(), res);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100842 }
843
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100844 },
845 input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100846 }
847};
848
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000849void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100850{
851 switch(axis)
852 {
853 case 0:
854 switch(input->info()->data_type())
855 {
856 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000857 return Reducer<RedOpX_qasymm8>::reduceX(window, input, output, RedOpX_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100858#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
859 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000860 return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100861#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
862 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000863 return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100864 default:
865 ARM_COMPUTE_ERROR("Not supported");
866 }
867 case 1:
868 switch(input->info()->data_type())
869 {
870 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000871 return Reducer<RedOpYZW_qasymm8>::reduceY(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100872#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
873 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000874 return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100875#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
876 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000877 return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100878 default:
879 ARM_COMPUTE_ERROR("Not supported");
880 }
881 case 2:
882 switch(input->info()->data_type())
883 {
884 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000885 return Reducer<RedOpYZW_qasymm8>::reduceZ(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100886#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
887 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000888 return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100889#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
890 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000891 return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100892 default:
893 ARM_COMPUTE_ERROR("Not supported");
894 }
895 case 3:
896 switch(input->info()->data_type())
897 {
898 case DataType::QASYMM8:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000899 return Reducer<RedOpYZW_qasymm8>::reduceW(window, input, output, RedOpYZW_qasymm8(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100900#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
901 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000902 return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100903#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
904 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000905 return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100906 default:
907 ARM_COMPUTE_ERROR("Not supported");
908 }
909 default:
910 ARM_COMPUTE_ERROR("Unsupported reduction axis");
911 }
912}
John Richardson73d4aef2018-05-08 14:34:33 +0100913
914Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
915{
916 ARM_COMPUTE_UNUSED(op);
917
918 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
Georgios Pinitas8f5802f2019-02-22 11:08:32 +0000919 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100920 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
John Richardson73d4aef2018-05-08 14:34:33 +0100921
922 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100923 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
John Richardson73d4aef2018-05-08 14:34:33 +0100924
925 if(output->total_size() != 0)
926 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000927 bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
928 if(!is_arg_min_max)
929 {
930 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Isabella Gottardi0a1090a2019-02-14 18:07:36 +0000931 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000932 }
933 else
934 {
935 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32);
936 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100937 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
John Richardson73d4aef2018-05-08 14:34:33 +0100938
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000939 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +0100940 const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
941 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
942 }
943
944 return Status{};
945}
946
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000947std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op)
John Richardson73d4aef2018-05-08 14:34:33 +0100948{
949 // Calculate output shape and set if empty
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000950 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +0100951
952 // Output auto initialization if not yet initialized
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000953 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
954 DataType output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
Isabella Gottardi0a1090a2019-02-14 18:07:36 +0000955 auto_init_if_empty(*output, output_shape, 1, output_data_type, input->quantization_info());
John Richardson73d4aef2018-05-08 14:34:33 +0100956
957 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
958
959 // Configure kernel window
960 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
961 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
962 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
963
964 bool window_changed = update_window_and_padding(win, input_access, output_access);
965 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
966
967 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
968
969 return std::make_tuple(err, win);
970}
Georgios Pinitasd9769582017-08-03 10:19:40 +0100971} // namespace
972
973NEReductionOperationKernel::NEReductionOperationKernel()
974 : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE), _border_size()
975{
976}
977
978BorderSize NEReductionOperationKernel::border_size() const
979{
980 return _border_size;
981}
982
983void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
984{
985 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100986
John Richardson73d4aef2018-05-08 14:34:33 +0100987 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100988
989 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
990
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100991 _input = input;
992 _output = output;
993 _border_size = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
994 _op = op;
995 _reduction_axis = axis;
Georgios Pinitasd9769582017-08-03 10:19:40 +0100996
997 // Configure kernel window
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000998 auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis, op);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100999
John Richardson73d4aef2018-05-08 14:34:33 +01001000 ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001001
John Richardson73d4aef2018-05-08 14:34:33 +01001002 INEKernel::configure(std::get<1>(win_config));
1003}
1004
1005Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1006{
1007 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001008 ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis, op)));
John Richardson73d4aef2018-05-08 14:34:33 +01001009
1010 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +01001011}
1012
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001013void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001014{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001015 ARM_COMPUTE_UNUSED(info);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001016 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1017 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1018
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001019 reduce_op(window, _input, _output, _reduction_axis, _op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001020}
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001021} // namespace arm_compute