/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEReductionOperationKernel.h"

#include "arm_compute/core/Coordinates.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/INEKernel.h"
#include "arm_compute/core/NEON/NEMath.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"

#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include <arm_neon.h>

namespace arm_compute
{
namespace
{
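/** Update the running argmin/argmax indices after comparing a float32x4 step.
 *
 * Where the new running value @p a beats the previous running value @p b (smaller for
 * ARG_IDX_MIN, larger otherwise), the stored index in @p c is replaced by the lane index
 * derived from @p idx; for reductions along a non-X axis all lanes share the index @p idx.
 */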
uint32x4x4_t calculate_index(uint32_t idx, float32x4_t a, float32x4_t b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4_t mask{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask = wrapper::vcgt(b, a);
    }
    else
    {
        mask = wrapper::vclt(b, a);
    }

    uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
    if(axis != 0)
    {
        vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    uint32x4x4_t res = { { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 } };

    return res;
}

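/** Update the running argmin/argmax indices after comparing a uint8x16 step.
 *
 * The 8-bit comparison mask is widened to four uint32x4 masks so that one 32-bit index
 * can be selected per input byte.
 */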
uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4x4_t mask{ { 0 } };
    uint8x16_t   mask_u8{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask_u8 = wrapper::vcgt(b, a);
    }
    else
    {
        mask_u8 = wrapper::vclt(b, a);
    }
    auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
    auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
    mask.val[0]     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
    mask.val[1]     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
    mask.val[2]     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
    mask.val[3]     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));

    uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
                               { idx + 4, idx + 5, idx + 6, idx + 7 },
                               { idx + 8, idx + 9, idx + 10, idx + 11 },
                               { idx + 12, idx + 13, idx + 14, idx + 15 }
                             }
                           };
    if(axis != 0)
    {
        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    uint32x4x4_t res =
    {
        {
            vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
            vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
            vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
            vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
        }
    };

    return res;
}

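/** Reduce a float32x4 running result to the final argmin/argmax scalar index.
 *
 * Finds the winning value across lanes, masks the index vector down to the matching
 * lanes and returns the smallest of the surviving indices.
 */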
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float32x4_t vec_res_value, ReductionOperation op)
{
    uint32x4_t res_idx_mask{ 0 };
    uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);

    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        auto pmin    = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmin         = wrapper::vpmin(pmin, pmin);
        auto mask    = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
        res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
    }
    else
    {
        auto pmax    = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmax         = wrapper::vpmax(pmax, pmax);
        auto mask    = vceqq_f32(vec_res_value, wrapper::vcombine(pmax, pmax));
        res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
    }

    res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
    auto pmin    = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
    pmin         = wrapper::vpmin(pmin, pmin);
    uint32_t res = wrapper::vgetlane(pmin, 0);

    return (res - 0xFFFFFFFF);
}

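/** Reduce a uint8x16 running result to the final argmin/argmax scalar index.
 *
 * Same scheme as the float32 variant, but the 8-bit match mask is widened to four
 * uint32x4 masks and the minimum index is taken across all four index vectors.
 */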
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_value, ReductionOperation op)
{
    uint32x4x4_t res_idx_mask{ { 0 } };
    uint32x4_t   mask_ones = vdupq_n_u32(0xFFFFFFFF);
    uint8x16_t   mask_u8{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmin      = wrapper::vpmin(pmin, pmin);
        pmin      = wrapper::vpmin(pmin, pmin);
        pmin      = wrapper::vpmin(pmin, pmin);
        mask_u8   = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
    }
    else
    {
        auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmax      = wrapper::vpmax(pmax, pmax);
        pmax      = wrapper::vpmax(pmax, pmax);
        pmax      = wrapper::vpmax(pmax, pmax);
        mask_u8   = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
    }

    // Widen vectors
    auto wide_u16_1     = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
    auto wide_u16_2     = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
    auto wide_u32_1     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
    auto wide_u32_2     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
    auto wide_u32_3     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
    auto wide_u32_4     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
    res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
    res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
    res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
    res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);

    uint32_t res  = 0xFFFFFFFF;
    int      iter = 0;
    do
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
        pmin      = wrapper::vpmin(pmin, pmin);
        res       = std::min(wrapper::vgetlane(pmin, 0), res);
        iter++;
    }
    while(iter < 4);

    return (res - 0xFFFFFFFF);
}
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
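/** Update the running argmin/argmax indices after comparing a float16x8 step (FP16 builds only). */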
uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4x2_t mask{ 0 };
    uint16x8_t   mask_u16{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask_u16 = wrapper::vcgt(b, a);
    }
    else
    {
        mask_u16 = wrapper::vclt(b, a);
    }
    mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
    mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
    uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
                               { idx + 4, idx + 5, idx + 6, idx + 7 }
                             }
                           };
    if(axis != 0)
    {
        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
                         wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
                         0, 0
                       };

    return res;
}

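/** Reduce a float16x8 running result to the final argmin/argmax scalar index (FP16 builds only). */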
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
{
    uint32x4x2_t res_idx_mask{ 0 };
    uint32x4_t   mask_ones = vdupq_n_u32(0xFFFFFFFF);
    uint16x8_t   mask_u16;
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmin      = wrapper::vpmin(pmin, pmin);
        pmin      = wrapper::vpmin(pmin, pmin);
        mask_u16  = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
    }
    else
    {
        auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
        pmax      = wrapper::vpmax(pmax, pmax);
        pmax      = wrapper::vpmax(pmax, pmax);
        mask_u16  = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
    }

    // Widen vectors
    auto wide_u32_1     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
    auto wide_u32_2     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);

    uint32_t res  = 0xFFFFFFFF;
    int      iter = 0;
    do
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
        pmin      = wrapper::vpmin(pmin, pmin);
        res       = std::min(wrapper::vgetlane(pmin, 0), res);
        iter++;
    }
    while(iter < 2);

    return (res - 0xFFFFFFFF);
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

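/** Drives a reduction functor @p F over the execution window, one slice at a time.
 *
 * reduceX() iterates 1D slices and lets the functor accumulate along X, while
 * reduceY/Z/W() collapse the reduced dimension to a single step and pass the axis
 * (1, 2 or 3) to the functor, which then walks that dimension itself.
 */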
template <class F>
class Reducer
{
public:
    static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set out window
        Window out_window(window);
        out_window.set(Window::DimX, Window::Dimension(0, 0, 0));

        // Get first input and output slices
        Window in_slice  = window.first_slice_window_1D();
        Window out_slice = out_window.first_slice_window_1D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), op);
        }
        while(window.slide_window_slice_1D(in_slice) && out_window.slide_window_slice_1D(out_slice));
    }
    static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in window
        Window in_window(window);
        Window out_window(window);

        in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
        out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));

        // Get first input and output slices
        Window in_slice  = in_window.first_slice_window_2D();
        Window out_slice = out_window.first_slice_window_2D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), 1, op);
        }
        while(in_window.slide_window_slice_2D(in_slice) && out_window.slide_window_slice_2D(out_slice));
    }
    static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in window
        Window in_window(window);
        Window out_window(window);

        in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
        out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));

        // Get first input and output slices
        Window in_slice  = in_window.first_slice_window_3D();
        Window out_slice = out_window.first_slice_window_3D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), 2, op);
        }
        while(in_window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_3D(out_slice));
    }
    static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in/out window
        Window in_window(window);
        Window out_window(window);

        in_window.set(3, Window::Dimension(0, 1, 1));
        out_window.set(3, Window::Dimension(0, 1, 1));

        // Get first input and output slices
        Window in_slice  = in_window.first_slice_window_4D();
        Window out_slice = out_window.first_slice_window_4D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), 3, op);
        }
        while(in_window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_4D(out_slice));
    }
};

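/** Reduction functor along the X axis for float types (instantiated with T = float, S = 4 and, on FP16 builds, T = float16_t, S = 8). */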
template <typename T, int S>
struct RedOpX
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;

    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
    {
        ARM_COMPUTE_UNUSED(out_slice);
        auto init_res_value = static_cast<T>(0.f);
        if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
        {
            init_res_value = *reinterpret_cast<T *>(input.ptr());
        }
        else if(op == ReductionOperation::PROD)
        {
            init_res_value = static_cast<T>(1.f);
        }
        auto         vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
        uint32x4x4_t vec_res_idx{ { 0 } };

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            const auto in_ptr       = reinterpret_cast<const T *>(input.ptr());
            const auto vec_elements = wrapper::vloadq(in_ptr);

            switch(op)
            {
                case ReductionOperation::SUM_SQUARE:
                    vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                    break;
                case ReductionOperation::MEAN_SUM:
                case ReductionOperation::SUM:
                    vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                    break;
                case ReductionOperation::PROD:
                    vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
                    break;
                case ReductionOperation::ARG_IDX_MIN:
                {
                    auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                    vec_res_idx             = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value           = temp_vec_res_value;
                    break;
                }
                case ReductionOperation::ARG_IDX_MAX:
                {
                    auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                    vec_res_idx             = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value           = temp_vec_res_value;
                    break;
                }
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        },
        input);

        switch(op)
        {
            case ReductionOperation::SUM:
            case ReductionOperation::SUM_SQUARE:
            case ReductionOperation::MEAN_SUM:
            {
                auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                for(int i = 0; i < S / 4; ++i)
                {
                    carry_res = wrapper::vpadd(carry_res, carry_res);
                }
                auto res = wrapper::vgetlane(carry_res, 0);

                if(op == ReductionOperation::MEAN_SUM)
                {
                    res /= in_info.dimension(0);
                }

                *(reinterpret_cast<T *>(output.ptr())) = res;
                break;
            }
            case ReductionOperation::PROD:
            {
                auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                T    res       = 1;
                for(int i = 0; i < S / 2; ++i)
                {
                    res *= wrapper::vgetlane(carry_res, i);
                }
                *(reinterpret_cast<T *>(output.ptr())) = res;
                break;
            }
            case ReductionOperation::ARG_IDX_MIN:
            case ReductionOperation::ARG_IDX_MAX:
            {
                auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
                *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
                break;
            }
            default:
                ARM_COMPUTE_ERROR("Not supported");
        }
    }
};

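/** Reduction functor along the X axis for QASYMM8 inputs.
 *
 * Sums are accumulated in four widened uint32x4 vectors; PROD de-quantizes the elements,
 * multiplies in float and re-quantizes the final scalar.
 */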
struct RedOpX_qasymm8
{
    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
    {
        ARM_COMPUTE_UNUSED(out_slice);
        auto vec_res_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_res_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_res_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_res_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));

        auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
        auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
        auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
        auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));

        uint8x16_t vec_res_value = { 0 };

        if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
        {
            vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{});
        }

        uint32x4x4_t vec_res_idx{ { 0 } };
        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            const auto vec_elements = wrapper::vloadq(input.ptr());
            switch(op)
            {
                case ReductionOperation::SUM:
                case ReductionOperation::MEAN_SUM:
                {
                    const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
                    const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));

                    const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
                    const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
                    const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
                    const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));

                    vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
                    vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
                    vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
                    vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
                    break;
                }
                case ReductionOperation::PROD:
                {
                    const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
                    const auto scale32x4f_4  = vdupq_n_f32(in_info.quantization_info().scale);

                    const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
                    const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));

                    const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
                    const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
                    const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
                    const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));

                    auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
                    auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
                    auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
                    auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);

                    //de-quantize vec_elements
                    temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
                    temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
                    temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
                    temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);

                    vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
                    vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
                    vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
                    vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
                    break;
                }
                case ReductionOperation::ARG_IDX_MIN:
                {
                    auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                    vec_res_idx             = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value           = temp_vec_res_value;
                    break;
                }
                case ReductionOperation::ARG_IDX_MAX:
                {
                    auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                    vec_res_idx             = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                    vec_res_value           = temp_vec_res_value;
                    break;
                }
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        },
        input);

        if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
        {
            auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
            *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
        }
        else if(op == ReductionOperation::PROD)
        {
            auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
            carry_res      = wrapper::vmul(carry_res, vec_res_value3_f);
            carry_res      = wrapper::vmul(carry_res, vec_res_value4_f);

            float res = wrapper::vgetlane(carry_res, 0);
            res *= wrapper::vgetlane(carry_res, 1);
            res *= wrapper::vgetlane(carry_res, 2);
            res *= wrapper::vgetlane(carry_res, 3);

            //re-quantize result
            res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset);
            *(output.ptr()) = static_cast<uint8_t>(res);
        }
        else
        {
            auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
            carry_res      = wrapper::vadd(carry_res, vec_res_value3);
            carry_res      = wrapper::vadd(carry_res, vec_res_value4);

            auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
            carry_paddition      = wrapper::vpadd(carry_paddition, carry_paddition);
            auto res             = wrapper::vgetlane(carry_paddition, 0);

            if(op == ReductionOperation::MEAN_SUM)
            {
                res /= in_info.dimension(0);
            }

            *(output.ptr()) = static_cast<uint8_t>(res);
        }
    }
};

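/** Reduction functor along the Y, Z or W axis for float types; the reduced axis is stepped
 *  one position at a time while a full vector of X elements is processed in parallel.
 */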
template <typename T, int S>
struct RedOpYZW
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
    using neon_vector  = typename wrapper::traits::neon_vector<T, S>::type;

    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
    {
        ARM_COMPUTE_UNUSED(out_slice);

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            neon_vector vec_res_value = { 0 };
            if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
            {
                vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
            }
            else if(op == ReductionOperation::PROD)
            {
                vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
            }
            else
            {
                vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
            }
            uint32x4x4_t vec_res_idx{ { 0 } };

            for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
            {
                T *in_ptr;
                switch(axis)
                {
                    case 1:
                        in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim)));
                        break;
                    case 2:
                        in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim)));
                        break;
                    case 3:
                        in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim)));
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
                const auto vec_elements = wrapper::vloadq(in_ptr);

                switch(op)
                {
                    case ReductionOperation::SUM:
                    case ReductionOperation::MEAN_SUM:
                        vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                        break;
                    case ReductionOperation::SUM_SQUARE:
                        vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                        break;
                    case ReductionOperation::PROD:
                        vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
                        break;
                    case ReductionOperation::ARG_IDX_MIN:
                    {
                        auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                        vec_res_idx             = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                        vec_res_value           = temp_vec_res_value;
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MAX:
                    {
                        auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                        vec_res_idx             = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                        vec_res_value           = temp_vec_res_value;
                        break;
                    }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            }

            if(op == ReductionOperation::MEAN_SUM)
            {
                auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
                vec_res_value      = wrapper::vmul(vec_res_value, vec_width_inv);
            }

            if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
            {
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
            }
            else
            {
                wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value);
            }
        },
        input, output);
    }
};

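/** Reduction functor along the Y, Z or W axis for QASYMM8 inputs, using widened uint32
 *  accumulators for sums and de-quantized float accumulators for PROD.
 */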
struct RedOpYZW_qasymm8
{
    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
    {
        ARM_COMPUTE_UNUSED(out_slice);

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            uint32x4x4_t vec_res_idx{ { 0 } };
            auto vec_res_value1 = vdupq_n_u32(0);
            auto vec_res_value2 = vdupq_n_u32(0);
            auto vec_res_value3 = vdupq_n_u32(0);
            auto vec_res_value4 = vdupq_n_u32(0);

            auto vec_res_value1_f = vdupq_n_f32(1);
            auto vec_res_value2_f = vdupq_n_f32(1);
            auto vec_res_value3_f = vdupq_n_f32(1);
            auto vec_res_value4_f = vdupq_n_f32(1);

            auto vec_res_value = wrapper::vloadq(input.ptr());

            for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
            {
                uint8_t *in_ptr;
                switch(axis)
                {
                    case 1:
                        in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, index_dim));
                        break;
                    case 2:
                        in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, index_dim));
                        break;
                    case 3:
                        in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, index_dim));
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
                const auto vec_elements = wrapper::vloadq(in_ptr);

                switch(op)
                {
                    case ReductionOperation::SUM:
                    case ReductionOperation::MEAN_SUM:
                    {
                        const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
                        const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));

                        const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
                        const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
                        const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
                        const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));

                        vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
                        vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
                        vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
                        vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
                        break;
                    }
                    case ReductionOperation::PROD:
                    {
                        const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
                        const auto scale32x4f_4  = vdupq_n_f32(in_info.quantization_info().scale);

                        const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
                        const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));

                        const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
                        const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
                        const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
                        const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));

                        auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
                        auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
                        auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
                        auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);

                        //de-quantize vec_elements
                        temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
                        temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
                        temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
                        temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);

                        vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
                        vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
                        vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
                        vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MIN:
                    {
                        auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                        vec_res_idx             = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                        vec_res_value           = temp_vec_res_value;
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MAX:
                    {
                        auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                        vec_res_idx             = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                        vec_res_value           = temp_vec_res_value;
                        break;
                    }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            }

            if(op == ReductionOperation::MEAN_SUM)
            {
                const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
                vec_res_value1_f         = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv);
                vec_res_value2_f         = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv);
                vec_res_value3_f         = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv);
                vec_res_value4_f         = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv);

                vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
                vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
                vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
                vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
            }
            else if(op == ReductionOperation::PROD)
            {
                const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
                const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(in_info.quantization_info().scale));

                //re-quantize
                vec_res_value1_f = vaddq_f32(vmulq_f32(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
                vec_res_value2_f = vaddq_f32(vmulq_f32(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
                vec_res_value3_f = vaddq_f32(vmulq_f32(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
                vec_res_value4_f = vaddq_f32(vmulq_f32(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);

                vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
                vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
                vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
                vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
            }

            if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
            {
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vec_res_idx.val[1]);
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vec_res_idx.val[2]);
                wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vec_res_idx.val[3]);
            }
            else
            {
                const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
                const auto temp16x8t_2 = vcombine_u16(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
                auto       res         = vcombine_u8(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
                wrapper::vstore(output.ptr(), res);
            }

        },
        input, output);
    }
};

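/** Dispatch the reduction to the functor matching the requested axis and data type. */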
void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
{
    switch(axis)
    {
        case 0:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpX_qasymm8>::reduceX(window, input, output, RedOpX_qasymm8(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 1:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8>::reduceY(window, input, output, RedOpYZW_qasymm8(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 2:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8>::reduceZ(window, input, output, RedOpYZW_qasymm8(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 3:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8>::reduceW(window, input, output, RedOpYZW_qasymm8(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        default:
            ARM_COMPUTE_ERROR("Unsupported reduction axis");
    }
}

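/** Validate input/output tensor info against the supported data types, axes and shapes. */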
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
    ARM_COMPUTE_UNUSED(op);

    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");

    if(output->total_size() != 0)
    {
        bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
        if(!is_arg_min_max)
        {
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
        }
        else
        {
            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32);
        }
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);

        const TensorShape output_shape         = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
        const TensorInfo  tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
    }

    return Status{};
}

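/** Auto-initialize the output (U32 for argmin/argmax) and compute the kernel window and padding requirements. */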
std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
    // Calculate output shape and set if empty
    const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);

    // Output auto initialization if not yet initialized
    const bool is_arg_min_max   = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
    DataType   output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
    auto_init_if_empty(*output, output_shape, 1, output_data_type, input->quantization_info());

    unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());

    // Configure kernel window
    Window                 win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
    AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);

    bool window_changed = update_window_and_padding(win, input_access, output_access);
    output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};

    return std::make_tuple(err, win);
}
} // namespace

NEReductionOperationKernel::NEReductionOperationKernel()
    : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE), _border_size()
{
}

BorderSize NEReductionOperationKernel::border_size() const
{
    return _border_size;
}

void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));

    unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());

    _input          = input;
    _output         = output;
    _border_size    = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
    _op             = op;
    _reduction_axis = axis;

    // Configure kernel window
    auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis, op);

    ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));

    INEKernel::configure(std::get<1>(win_config));
}

Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
    ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis, op)));

    return Status{};
}

void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    reduce_op(window, _input, _output, _reduction_axis, _op);
}
} // namespace arm_compute