blob: e2dee67d01b7238837f1f65bf735f310b072438b [file] [log] [blame]
Georgios Pinitasd9769582017-08-03 10:19:40 +01001/*
Pablo Tello07958782020-01-09 14:43:36 +00002 * Copyright (c) 2017-2020 ARM Limited.
Georgios Pinitasd9769582017-08-03 10:19:40 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEReductionOperationKernel.h"
25
Georgios Pinitas8f5802f2019-02-22 11:08:32 +000026#include "arm_compute/core/CPP/Validate.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010027#include "arm_compute/core/Coordinates.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/IAccessWindow.h"
30#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/INEKernel.h"
32#include "arm_compute/core/NEON/NEMath.h"
John Richardson73d4aef2018-05-08 14:34:33 +010033#include "arm_compute/core/TensorInfo.h"
Luca Foschianiee939fb2020-01-28 10:38:07 +000034#include "arm_compute/core/Utils.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010035#include "arm_compute/core/Validate.h"
Michalis Spyrou19bd4122020-01-22 10:27:06 +000036#include "arm_compute/core/utils/misc/SaturateCast.h"
Michalis Spyrouaea14c62019-01-03 11:10:25 +000037#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010038
Michalis Spyroubcf8a962018-10-12 10:51:31 +010039#include "arm_compute/core/NEON/wrapper/wrapper.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010040#include <arm_neon.h>
41
Michalis Spyroubcf8a962018-10-12 10:51:31 +010042namespace arm_compute
43{
Georgios Pinitasd9769582017-08-03 10:19:40 +010044namespace
45{
Luca Foschianiee939fb2020-01-28 10:38:07 +000046// Helper function that calls vqmovun/vqmvn, vcombine and vstore, allows templating of RedOpYZW_quantized
47template <typename T>
48void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output)
49{
50 if(std::is_same<T, uint8_t>::value)
51 {
52 auto res = wrapper::vcombine(wrapper::vqmovun(t1), wrapper::vqmovun(t2));
53 wrapper::vstore(output.ptr(), res);
54 }
55 else
56 {
57 auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2));
58 wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr()), res);
59 }
60}
61
Michalis Spyroub9626ab2019-05-13 17:41:01 +010062template <typename T>
63uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000064{
65 uint32x4_t mask{ 0 };
66 if(op == ReductionOperation::ARG_IDX_MIN)
67 {
68 mask = wrapper::vcgt(b, a);
69 }
70 else
71 {
72 mask = wrapper::vclt(b, a);
73 }
74
75 uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
76 if(axis != 0)
77 {
78 vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
79 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +000080 uint32x4x4_t res = { { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000081
82 return res;
83}
84
Luca Foschianiee939fb2020-01-28 10:38:07 +000085template <typename T>
86uint32x4x4_t calculate_index_quantized(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000087{
Georgios Pinitasd57891a2019-02-19 18:10:03 +000088 uint32x4x4_t mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000089 uint8x16_t mask_u8{ 0 };
90 if(op == ReductionOperation::ARG_IDX_MIN)
91 {
92 mask_u8 = wrapper::vcgt(b, a);
93 }
94 else
95 {
96 mask_u8 = wrapper::vclt(b, a);
97 }
Michalis Spyrou254a48a2019-01-14 17:27:39 +000098 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
99 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
100 mask.val[0] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
101 mask.val[1] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
102 mask.val[2] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
103 mask.val[3] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
104
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000105 uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
106 { idx + 4, idx + 5, idx + 6, idx + 7 },
107 { idx + 8, idx + 9, idx + 10, idx + 11 },
108 { idx + 12, idx + 13, idx + 14, idx + 15 }
109 }
110 };
111 if(axis != 0)
112 {
113 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
114 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
115 vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
116 vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
117 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000118 uint32x4x4_t res =
119 {
120 {
121 vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
122 vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
123 vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
124 vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
125 }
126 };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000127
128 return res;
129}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100130
131// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000132template <typename T>
133inline typename std::enable_if < std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
134 typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type >::type
135 calculate_min(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100136{
137 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
138 return wrapper::vpmin(pmin, pmin);
139}
140
Luca Foschianiee939fb2020-01-28 10:38:07 +0000141// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
142template <typename T>
143inline typename std::enable_if < std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
144 typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type >::type
145 calculate_min(T in)
146{
147 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
148 pmin = wrapper::vpmin(pmin, pmin);
149 pmin = wrapper::vpmin(pmin, pmin);
150 return wrapper::vpmin(pmin, pmin);
151}
152
Usama Arifa4a08ad2019-05-20 12:38:33 +0100153// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000154template <typename T>
155inline typename std::enable_if < std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
156 typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type >::type
157 calculate_max(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100158{
159 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
160 return wrapper::vpmax(pmax, pmax);
161}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100162
163// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000164template <typename T>
165inline typename std::enable_if < std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
166 typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type >::type
167 calculate_max(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100168{
169 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
Luca Foschianiee939fb2020-01-28 10:38:07 +0000170 pmax = wrapper::vpmax(pmax, pmax);
171 pmax = wrapper::vpmax(pmax, pmax);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100172 return wrapper::vpmax(pmax, pmax);
173}
174
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100175template <typename T>
176uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000177{
178 uint32x4_t res_idx_mask{ 0 };
179 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
180
181 if(op == ReductionOperation::ARG_IDX_MIN)
182 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100183 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000184 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
185 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
186 }
187 else
188 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100189 auto pmax = calculate_max(vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100190 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000191 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
192 }
193
194 res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
195 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
196 pmin = wrapper::vpmin(pmin, pmin);
197 uint32_t res = wrapper::vgetlane(pmin, 0);
198
199 return (res - 0xFFFFFFFF);
200}
201
Luca Foschianiee939fb2020-01-28 10:38:07 +0000202template <typename T>
203uint32_t calculate_vector_index_quantized(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000204{
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000205 uint32x4x4_t res_idx_mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000206 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
207 uint8x16_t mask_u8{ 0 };
208 if(op == ReductionOperation::ARG_IDX_MIN)
209 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100210 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000211 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
212 }
213 else
214 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100215 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000216 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
217 }
218
219 // Widen vectors
220 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
221 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
222 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
223 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
224 auto wide_u32_3 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
225 auto wide_u32_4 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
226 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
227 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
228 res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
229 res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
230 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
231 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
232 res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
233 res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);
234
235 uint32_t res = 0xFFFFFFFF;
236 int iter = 0;
237 do
238 {
239 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
240 pmin = wrapper::vpmin(pmin, pmin);
241 res = std::min(wrapper::vgetlane(pmin, 0), res);
242 iter++;
243 }
244 while(iter < 4);
245
246 return (res - 0xFFFFFFFF);
247}
Luca Foschianiee939fb2020-01-28 10:38:07 +0000248
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000249#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitasfad18382019-06-05 15:12:22 +0100250template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000251uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
252{
253 uint32x4x2_t mask{ 0 };
254 uint16x8_t mask_u16{ 0 };
255 if(op == ReductionOperation::ARG_IDX_MIN)
256 {
257 mask_u16 = wrapper::vcgt(b, a);
258 }
259 else
260 {
261 mask_u16 = wrapper::vclt(b, a);
262 }
263 mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
264 mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
265 uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
266 { idx + 4, idx + 5, idx + 6, idx + 7 }
267 }
268 };
269 if(axis != 0)
270 {
271 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
272 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
273 }
274 uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
275 wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
276 0, 0
277 };
278
279 return res;
280}
281
Usama Arifa4a08ad2019-05-20 12:38:33 +0100282// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
283inline float16x4_t calculate_min(float16x8_t in)
284{
285 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
286 pmin = wrapper::vpmin(pmin, pmin);
287 return wrapper::vpmin(pmin, pmin);
288}
289// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
290inline float16x4_t calculate_max(float16x8_t in)
291{
292 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
293 pmax = wrapper::vpmax(pmax, pmax);
294 return wrapper::vpmax(pmax, pmax);
295}
296
Usama Arif0a5a57a2019-05-23 14:20:33 +0100297template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000298uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
299{
300 uint32x4x2_t res_idx_mask{ 0 };
301 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
302 uint16x8_t mask_u16;
303 if(op == ReductionOperation::ARG_IDX_MIN)
304 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100305 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000306 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
307 }
308 else
309 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100310 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000311 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
312 }
313
314 // Widen vectors
315 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
316 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
317 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
318 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
319 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
320 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
321
322 uint32_t res = 0xFFFFFFFF;
323 int iter = 0;
324 do
325 {
326 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
327 pmin = wrapper::vpmin(pmin, pmin);
328 res = std::min(wrapper::vgetlane(pmin, 0), res);
329 iter++;
330 }
331 while(iter < 2);
332
333 return (res - 0xFFFFFFFF);
334}
335#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
336
Georgios Pinitasd9769582017-08-03 10:19:40 +0100337template <class F>
338class Reducer
339{
340public:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000341 static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100342 {
343 // Set out window
344 Window out_window(window);
345 out_window.set(Window::DimX, Window::Dimension(0, 0, 0));
346
347 // Get first input and output slices
348 Window in_slice = window.first_slice_window_1D();
349 Window out_slice = out_window.first_slice_window_1D();
350
351 do
352 {
353 Iterator in(input, in_slice);
354 Iterator out(output, out_slice);
355
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000356 f(in, out, in_slice, out_slice, *input->info(), op);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100357 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100358 while(window.slide_window_slice_1D(in_slice) && out_window.slide_window_slice_1D(out_slice));
359 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000360 static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100361 {
362 // Set in window
363 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000364 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100365
366 in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000367 out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100368
369 // Get first input and output slices
370 Window in_slice = in_window.first_slice_window_2D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000371 Window out_slice = out_window.first_slice_window_2D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100372
373 do
374 {
375 Iterator in(input, in_slice);
376 Iterator out(output, out_slice);
377
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000378 f(in, out, in_slice, out_slice, *input->info(), 1, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100379 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000380 while(in_window.slide_window_slice_2D(in_slice) && out_window.slide_window_slice_2D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100381 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000382 static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100383 {
384 // Set in window
385 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000386 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100387
388 in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000389 out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100390
391 // Get first input and output slices
392 Window in_slice = in_window.first_slice_window_3D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000393 Window out_slice = out_window.first_slice_window_3D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100394
395 do
396 {
397 Iterator in(input, in_slice);
398 Iterator out(output, out_slice);
399
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000400 f(in, out, in_slice, out_slice, *input->info(), 2, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100401 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000402 while(in_window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_3D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100403 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000404 static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100405 {
406 // Set in/out window
407 Window in_window(window);
408 Window out_window(window);
409
410 in_window.set(3, Window::Dimension(0, 1, 1));
411 out_window.set(3, Window::Dimension(0, 1, 1));
412
413 // Get first input and output slices
414 Window in_slice = in_window.first_slice_window_4D();
415 Window out_slice = out_window.first_slice_window_4D();
416
417 do
418 {
419 Iterator in(input, in_slice);
420 Iterator out(output, out_slice);
421
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000422 f(in, out, in_slice, out_slice, *input->info(), 3, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100423 }
424 while(in_window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_4D(out_slice));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100425 }
426};
427
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000428template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100429struct RedOpX
Georgios Pinitasd9769582017-08-03 10:19:40 +0100430{
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100431 /** NEON vector tag type. */
432 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
433
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000434 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100435 {
436 ARM_COMPUTE_UNUSED(out_slice);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000437 auto init_res_value = static_cast<T>(0.f);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100438 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000439 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100440 case ReductionOperation::ARG_IDX_MAX:
441 case ReductionOperation::ARG_IDX_MIN:
442 case ReductionOperation::MIN:
Usama Arif28f0dd92019-05-20 13:44:34 +0100443 case ReductionOperation::MAX:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100444 {
445 init_res_value = *reinterpret_cast<T *>(input.ptr());
446 break;
447 }
448 case ReductionOperation::PROD:
449 {
450 init_res_value = static_cast<T>(1.f);
451 break;
452 }
453 default:
454 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000455 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000456 auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000457 uint32x4x4_t vec_res_idx{ { 0 } };
Georgios Pinitasd9769582017-08-03 10:19:40 +0100458
459 execute_window_loop(in_slice, [&](const Coordinates & id)
460 {
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100461 const auto in_ptr = reinterpret_cast<const T *>(input.ptr());
462 const auto vec_elements = wrapper::vloadq(in_ptr);
463
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000464 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100465 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000466 case ReductionOperation::SUM_SQUARE:
467 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
468 break;
469 case ReductionOperation::MEAN_SUM:
470 case ReductionOperation::SUM:
471 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
472 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000473 case ReductionOperation::PROD:
474 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
475 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000476 case ReductionOperation::ARG_IDX_MIN:
477 {
478 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100479 vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000480 vec_res_value = temp_vec_res_value;
481 break;
482 }
483 case ReductionOperation::ARG_IDX_MAX:
484 {
485 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100486 vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000487 vec_res_value = temp_vec_res_value;
488 break;
489 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100490 case ReductionOperation::MIN:
491 {
492 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
493 break;
494 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100495 case ReductionOperation::MAX:
496 {
497 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
498 break;
499 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000500 default:
501 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100502 }
Georgios Pinitasd9769582017-08-03 10:19:40 +0100503 },
504 input);
505
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000506 switch(op)
Michele Di Giorgio1c948d42018-11-20 16:03:01 +0000507 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000508 case ReductionOperation::SUM:
509 case ReductionOperation::SUM_SQUARE:
510 case ReductionOperation::MEAN_SUM:
511 {
512 auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
513 for(int i = 0; i < S / 4; ++i)
514 {
515 carry_res = wrapper::vpadd(carry_res, carry_res);
516 }
517 auto res = wrapper::vgetlane(carry_res, 0);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100518
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000519 if(op == ReductionOperation::MEAN_SUM)
520 {
521 res /= in_info.dimension(0);
522 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100523
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000524 *(reinterpret_cast<T *>(output.ptr())) = res;
525 break;
526 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000527 case ReductionOperation::PROD:
528 {
529 auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
530 T res = 1;
531 for(int i = 0; i < S / 2; ++i)
532 {
533 res *= wrapper::vgetlane(carry_res, i);
534 }
535 *(reinterpret_cast<T *>(output.ptr())) = res;
536 break;
537 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000538 case ReductionOperation::ARG_IDX_MIN:
539 case ReductionOperation::ARG_IDX_MAX:
540 {
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100541 auto res = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000542 *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
543 break;
544 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100545 case ReductionOperation::MIN:
546 {
547 *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0);
548 break;
549 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100550 case ReductionOperation::MAX:
551 {
552 *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_max(vec_res_value), 0);
553 break;
554 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000555 default:
556 ARM_COMPUTE_ERROR("Not supported");
557 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100558 }
559};
560
Luca Foschianiee939fb2020-01-28 10:38:07 +0000561template <typename T>
562struct RedOpX_quantized
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100563{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000564 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100565 {
566 ARM_COMPUTE_UNUSED(out_slice);
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100567
Luca Foschianiee939fb2020-01-28 10:38:07 +0000568 using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
569
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100570 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
571
Luca Foschianiee939fb2020-01-28 10:38:07 +0000572 auto vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
573 auto vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
574 auto vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
575 auto vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100576
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000577 auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
578 auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
579 auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
580 auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));
581
Luca Foschianiee939fb2020-01-28 10:38:07 +0000582 typename wrapper::traits::neon_vector<T, 16>::type vec_res_value = { 0 };
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000583
Usama Arif28f0dd92019-05-20 13:44:34 +0100584 if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000585 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000586 vec_res_value = wrapper::vdup_n(*reinterpret_cast<T *>(input.ptr()), wrapper::traits::vector_128_tag{});
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000587 }
588
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000589 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100590 execute_window_loop(in_slice, [&](const Coordinates & id)
591 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000592 const auto vec_elements = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000593 switch(op)
594 {
595 case ReductionOperation::SUM:
596 case ReductionOperation::MEAN_SUM:
597 {
598 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
599 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100600
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000601 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
602 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
603 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
604 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100605
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000606 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
607 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
608 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
609 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
610 break;
611 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000612 case ReductionOperation::PROD:
613 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100614 const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
615 const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000616
Luca Foschianiee939fb2020-01-28 10:38:07 +0000617 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
618 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000619
Luca Foschianiee939fb2020-01-28 10:38:07 +0000620 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
621 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
622 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
623 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000624
Luca Foschianiee939fb2020-01-28 10:38:07 +0000625 auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
626 auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
627 auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
628 auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000629
630 //de-quantize vec_elements
631 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
632 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
633 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
634 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
635
636 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
637 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
638 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
639 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
640 break;
641 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000642 case ReductionOperation::ARG_IDX_MIN:
643 {
644 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Luca Foschianiee939fb2020-01-28 10:38:07 +0000645 vec_res_idx = calculate_index_quantized(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000646 vec_res_value = temp_vec_res_value;
647 break;
648 }
649 case ReductionOperation::ARG_IDX_MAX:
650 {
651 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Luca Foschianiee939fb2020-01-28 10:38:07 +0000652 vec_res_idx = calculate_index_quantized(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000653 vec_res_value = temp_vec_res_value;
654 break;
655 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100656 case ReductionOperation::MIN:
657 {
658 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
659 break;
660 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100661 case ReductionOperation::MAX:
662 {
663 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
664 break;
665 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000666 default:
667 ARM_COMPUTE_ERROR("Not supported");
668 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100669 },
670 input);
671
Usama Arifa4a08ad2019-05-20 12:38:33 +0100672 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100673 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100674 case ReductionOperation::ARG_IDX_MIN:
675 case ReductionOperation::ARG_IDX_MAX:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000676 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000677 auto res = calculate_vector_index_quantized(vec_res_idx, vec_res_value, op);
678 *(reinterpret_cast<PromotedType *>(output.ptr())) = res;
Usama Arifa4a08ad2019-05-20 12:38:33 +0100679 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000680 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100681 case ReductionOperation::MIN:
682 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000683 *(output.ptr()) = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
Usama Arifa4a08ad2019-05-20 12:38:33 +0100684 break;
685 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100686 case ReductionOperation::MAX:
687 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000688 *(output.ptr()) = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
Usama Arif28f0dd92019-05-20 13:44:34 +0100689 break;
690 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100691 case ReductionOperation::PROD:
692 {
693 auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
694 carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
695 carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000696
Usama Arifa4a08ad2019-05-20 12:38:33 +0100697 float res = wrapper::vgetlane(carry_res, 0);
698 res *= wrapper::vgetlane(carry_res, 1);
699 res *= wrapper::vgetlane(carry_res, 2);
700 res *= wrapper::vgetlane(carry_res, 3);
701
702 //re-quantize result
Luca Foschianiee939fb2020-01-28 10:38:07 +0000703 if(std::is_same<T, uint8_t>::value)
704 {
705 res = quantize_qasymm8(res, iq_info);
706 }
707 else
708 {
709 res = quantize_qasymm8_signed(res, iq_info);
710 }
711
712 *reinterpret_cast<T *>(output.ptr()) = static_cast<T>(res);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100713 break;
714 }
715 default:
716 {
717 auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
718 carry_res = wrapper::vadd(carry_res, vec_res_value3);
719 carry_res = wrapper::vadd(carry_res, vec_res_value4);
720
721 auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
722 carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000723 auto res = static_cast<int32_t>(wrapper::vgetlane(carry_paddition, 0));
Usama Arifa4a08ad2019-05-20 12:38:33 +0100724
725 if(op == ReductionOperation::MEAN_SUM)
726 {
727 res /= in_info.dimension(0);
728 }
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000729 else
730 {
731 // Subtract accumulated offsets
732 res -= (in_info.dimension(0) - 1) * iq_info.offset;
733 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100734
Luca Foschianiee939fb2020-01-28 10:38:07 +0000735 *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(res);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100736 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000737 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100738 }
739};
740
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000741template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100742struct RedOpYZW
743{
744 /** NEON vector tag type. */
745 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000746 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100747
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000748 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100749 {
750 ARM_COMPUTE_UNUSED(out_slice);
751
giuros01154bc1c2019-03-26 17:44:40 +0000752 execute_window_loop(in_slice, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100753 {
Michalis Spyrou728d6f72019-01-16 13:57:58 +0000754 neon_vector vec_res_value = { 0 };
Usama Arifa4a08ad2019-05-20 12:38:33 +0100755 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000756 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100757 case ReductionOperation::ARG_IDX_MAX:
758 case ReductionOperation::ARG_IDX_MIN:
759 case ReductionOperation::MIN:
Usama Arif28f0dd92019-05-20 13:44:34 +0100760 case ReductionOperation::MAX:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100761 {
762 vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
763 break;
764 }
765 case ReductionOperation::PROD:
766 {
767 vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
768 break;
769 }
770 default:
771 {
772 vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
773 break;
774 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000775 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000776 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000777
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100778 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
779 {
Pablo Tello07958782020-01-09 14:43:36 +0000780 const T *in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.strides_in_bytes()[axis] * dim);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100781 const auto vec_elements = wrapper::vloadq(in_ptr);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000782 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100783 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000784 case ReductionOperation::SUM:
785 case ReductionOperation::MEAN_SUM:
786 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
787 break;
788 case ReductionOperation::SUM_SQUARE:
789 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
790 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000791 case ReductionOperation::PROD:
792 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
793 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000794 case ReductionOperation::ARG_IDX_MIN:
795 {
796 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
797 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
798 vec_res_value = temp_vec_res_value;
799 break;
800 }
801 case ReductionOperation::ARG_IDX_MAX:
802 {
803 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
804 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
805 vec_res_value = temp_vec_res_value;
806 break;
807 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100808 case ReductionOperation::MIN:
809 {
810 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
811 break;
812 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100813 case ReductionOperation::MAX:
814 {
815 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
816 break;
817 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000818 default:
819 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100820 }
821 }
822
823 if(op == ReductionOperation::MEAN_SUM)
824 {
825 auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000826 vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100827 }
828
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000829 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
830 {
831 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
Michele Di Giorgio81d7e782019-08-16 18:03:35 +0100832#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
833 if(std::is_same<T, float16_t>::value)
834 {
835 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vec_res_idx.val[1]);
836 }
837#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000838 }
839 else
840 {
841 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value);
842 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100843 },
844 input, output);
845 }
846};
847
giuros01154bc1c2019-03-26 17:44:40 +0000848template <typename T, int S, int axis, ReductionOperation op>
849struct RedOpYZW_complex
850{
851 /** NEON vector tag type. */
852 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
853 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
854
855 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int, const ReductionOperation)
856 {
857 ARM_COMPUTE_UNUSED(out_slice);
858 ARM_COMPUTE_ERROR_ON(axis != 2);
859
860 const size_t stride_z = in_info.strides_in_bytes()[axis];
861
862 execute_window_loop(in_slice, [&](const Coordinates &)
863 {
864 neon_vector vec_res_value_0 = { 0 };
865 neon_vector vec_res_value_1 = { 0 };
866
867 vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
868 vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
869
870 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
871 {
872 T *in_ptr_0;
873 T *in_ptr_1;
874 switch(axis)
875 {
876 case 2:
877 in_ptr_0 = reinterpret_cast<T *>(input.ptr() + stride_z * dim);
878 in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 16 + stride_z * dim);
879 break;
880 default:
881 ARM_COMPUTE_ERROR("Not supported");
882 }
883 const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
884 const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);
885
886 switch(op)
887 {
888 case ReductionOperation::SUM:
889 vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
890 vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
891 break;
892 default:
893 ARM_COMPUTE_ERROR("Not supported");
894 }
895 }
896
897 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value_0);
898 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + 16), vec_res_value_1);
899
900 },
901 input, output);
902 }
903};
904
Luca Foschianiee939fb2020-01-28 10:38:07 +0000905template <typename T>
906struct RedOpYZW_quantized
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100907{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000908 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100909 {
910 ARM_COMPUTE_UNUSED(out_slice);
911
Luca Foschianiee939fb2020-01-28 10:38:07 +0000912 using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
913
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100914 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
915
giuros01154bc1c2019-03-26 17:44:40 +0000916 execute_window_loop(in_slice, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100917 {
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000918 uint32x4x4_t vec_res_idx{ { 0 } };
Luca Foschianiee939fb2020-01-28 10:38:07 +0000919 auto vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
920 auto vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
921 auto vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
922 auto vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000923
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000924 auto vec_res_value1_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
925 auto vec_res_value2_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
926 auto vec_res_value3_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
927 auto vec_res_value4_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000928
Luca Foschianiee939fb2020-01-28 10:38:07 +0000929 auto vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000930
931 for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100932 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000933 const T *in_ptr = reinterpret_cast<T *>(input.ptr()) + in_info.strides_in_bytes()[axis] * index_dim;
934 const auto vec_elements = wrapper::vloadq(in_ptr);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000935 switch(op)
936 {
937 case ReductionOperation::SUM:
938 case ReductionOperation::MEAN_SUM:
939 {
940 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
941 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100942
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000943 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
944 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
945 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
946 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100947
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000948 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
949 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
950 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
951 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
952 break;
953 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000954 case ReductionOperation::PROD:
955 {
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000956 const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
957 const auto scale32x4f_4 = wrapper::vdup_n(iq_info.scale, wrapper::traits::vector_128_tag{});
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000958
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000959 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
960 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000961
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000962 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
963 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
964 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
965 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000966
Luca Foschianiee939fb2020-01-28 10:38:07 +0000967 auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
968 auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
969 auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
970 auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000971
972 //de-quantize vec_elements
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000973 temp32x4f_1 = wrapper::vmul(wrapper::vsub(temp32x4f_1, offset32x4f_4), scale32x4f_4);
974 temp32x4f_2 = wrapper::vmul(wrapper::vsub(temp32x4f_2, offset32x4f_4), scale32x4f_4);
975 temp32x4f_3 = wrapper::vmul(wrapper::vsub(temp32x4f_3, offset32x4f_4), scale32x4f_4);
976 temp32x4f_4 = wrapper::vmul(wrapper::vsub(temp32x4f_4, offset32x4f_4), scale32x4f_4);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000977
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000978 vec_res_value1_f = wrapper::vmul(temp32x4f_1, vec_res_value1_f);
979 vec_res_value2_f = wrapper::vmul(temp32x4f_2, vec_res_value2_f);
980 vec_res_value3_f = wrapper::vmul(temp32x4f_3, vec_res_value3_f);
981 vec_res_value4_f = wrapper::vmul(temp32x4f_4, vec_res_value4_f);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000982 break;
983 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000984 case ReductionOperation::ARG_IDX_MIN:
985 {
986 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Luca Foschianiee939fb2020-01-28 10:38:07 +0000987 vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000988 vec_res_value = temp_vec_res_value;
989 break;
990 }
991 case ReductionOperation::ARG_IDX_MAX:
992 {
993 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Luca Foschianiee939fb2020-01-28 10:38:07 +0000994 vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000995 vec_res_value = temp_vec_res_value;
996 break;
997 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100998 case ReductionOperation::MIN:
999 {
1000 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1001 break;
1002 }
Usama Arif28f0dd92019-05-20 13:44:34 +01001003 case ReductionOperation::MAX:
1004 {
1005 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1006 break;
1007 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001008 default:
1009 ARM_COMPUTE_ERROR("Not supported");
1010 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001011 }
1012
1013 if(op == ReductionOperation::MEAN_SUM)
1014 {
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001015 const auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<float>(in_info.dimension(axis)), wrapper::traits::vector_128_tag{}));
Luca Foschianiee939fb2020-01-28 10:38:07 +00001016 vec_res_value1_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value1), vec_width_inv);
1017 vec_res_value2_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value2), vec_width_inv);
1018 vec_res_value3_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value3), vec_width_inv);
1019 vec_res_value4_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value4), vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001020
Luca Foschianiee939fb2020-01-28 10:38:07 +00001021 vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f);
1022 vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f);
1023 vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f);
1024 vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f);
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001025 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001026 else if(op == ReductionOperation::PROD)
1027 {
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001028 const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01001029 const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale));
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001030
1031 //re-quantize
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001032 vec_res_value1_f = wrapper::vadd(wrapper::vmul(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
1033 vec_res_value2_f = wrapper::vadd(wrapper::vmul(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
1034 vec_res_value3_f = wrapper::vadd(wrapper::vmul(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
1035 vec_res_value4_f = wrapper::vadd(wrapper::vmul(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001036
Luca Foschianiee939fb2020-01-28 10:38:07 +00001037 vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f);
1038 vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f);
1039 vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f);
1040 vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f);
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001041 }
1042
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001043 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
1044 {
1045 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
1046 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vec_res_idx.val[1]);
1047 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vec_res_idx.val[2]);
1048 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vec_res_idx.val[3]);
1049 }
Usama Arifa4a08ad2019-05-20 12:38:33 +01001050 else if(op == ReductionOperation::ARG_IDX_MIN)
1051 {
Luca Foschianiee939fb2020-01-28 10:38:07 +00001052 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value);
Usama Arifa4a08ad2019-05-20 12:38:33 +01001053 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001054 else
1055 {
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001056 if(op == ReductionOperation::SUM)
1057 {
1058 // Subtract offsets
1059 auto offsets = vdupq_n_s32((in_info.dimension(axis) - 1) * iq_info.offset);
1060
Luca Foschianiee939fb2020-01-28 10:38:07 +00001061 auto vec_res_s_value1 = wrapper::vreinterpret(vec_res_value1);
1062 auto vec_res_s_value2 = wrapper::vreinterpret(vec_res_value2);
1063 auto vec_res_s_value3 = wrapper::vreinterpret(vec_res_value3);
1064 auto vec_res_s_value4 = wrapper::vreinterpret(vec_res_value4);
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001065
1066 vec_res_s_value1 = wrapper::vsub(vec_res_s_value1, offsets);
1067 vec_res_s_value2 = wrapper::vsub(vec_res_s_value2, offsets);
1068 vec_res_s_value3 = wrapper::vsub(vec_res_s_value3, offsets);
1069 vec_res_s_value4 = wrapper::vsub(vec_res_s_value4, offsets);
1070
1071 const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_s_value1), wrapper::vqmovn(vec_res_s_value2));
1072 const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_s_value3), wrapper::vqmovn(vec_res_s_value4));
Luca Foschianiee939fb2020-01-28 10:38:07 +00001073
1074 combine_and_store<T>(temp16x8t_1, temp16x8t_2, output);
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001075 }
1076 else
1077 {
1078 const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1079 const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1080 auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
Luca Foschianiee939fb2020-01-28 10:38:07 +00001081
1082 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), res);
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001083 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001084 }
1085
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001086 },
1087 input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001088 }
1089};
1090
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001091void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001092{
giuros01154bc1c2019-03-26 17:44:40 +00001093 const bool is_complex = (input->info()->num_channels() == 2);
1094
1095 if(is_complex)
1096 {
1097 switch(axis)
1098 {
1099 case 2:
1100 switch(input->info()->data_type())
1101 {
1102 case DataType::F32:
1103 switch(op)
1104 {
1105 case ReductionOperation::SUM:
1106 return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op);
1107 default:
1108 ARM_COMPUTE_ERROR("Not supported");
1109 }
1110 default:
1111 ARM_COMPUTE_ERROR("Not supported");
1112 }
1113 default:
1114 ARM_COMPUTE_ERROR("Not supported");
1115 }
1116 }
1117
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001118 switch(axis)
1119 {
1120 case 0:
1121 switch(input->info()->data_type())
1122 {
1123 case DataType::QASYMM8:
Luca Foschianiee939fb2020-01-28 10:38:07 +00001124 return Reducer<RedOpX_quantized<uint8_t>>::reduceX(window, input, output, RedOpX_quantized<uint8_t>(), op);
1125 case DataType::QASYMM8_SIGNED:
1126 return Reducer<RedOpX_quantized<int8_t>>::reduceX(window, input, output, RedOpX_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001127#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1128 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001129 return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001130#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1131 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001132 return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001133 case DataType::S32:
1134 return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001135 default:
1136 ARM_COMPUTE_ERROR("Not supported");
1137 }
1138 case 1:
1139 switch(input->info()->data_type())
1140 {
1141 case DataType::QASYMM8:
Luca Foschianiee939fb2020-01-28 10:38:07 +00001142 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceY(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
1143 case DataType::QASYMM8_SIGNED:
1144 return Reducer<RedOpYZW_quantized<int8_t>>::reduceY(window, input, output, RedOpYZW_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001145#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1146 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001147 return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001148#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1149 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001150 return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001151 case DataType::S32:
1152 return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001153 default:
1154 ARM_COMPUTE_ERROR("Not supported");
1155 }
1156 case 2:
1157 switch(input->info()->data_type())
1158 {
1159 case DataType::QASYMM8:
Luca Foschianiee939fb2020-01-28 10:38:07 +00001160 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
1161 case DataType::QASYMM8_SIGNED:
1162 return Reducer<RedOpYZW_quantized<int8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001163#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1164 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001165 return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001166#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1167 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001168 return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001169 case DataType::S32:
1170 return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001171 default:
1172 ARM_COMPUTE_ERROR("Not supported");
1173 }
1174 case 3:
1175 switch(input->info()->data_type())
1176 {
1177 case DataType::QASYMM8:
Luca Foschianiee939fb2020-01-28 10:38:07 +00001178 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceW(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
1179 case DataType::QASYMM8_SIGNED:
1180 return Reducer<RedOpYZW_quantized<int8_t>>::reduceW(window, input, output, RedOpYZW_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001181#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1182 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001183 return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001184#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1185 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001186 return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001187 case DataType::S32:
1188 return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001189 default:
1190 ARM_COMPUTE_ERROR("Not supported");
1191 }
1192 default:
1193 ARM_COMPUTE_ERROR("Unsupported reduction axis");
1194 }
1195}
John Richardson73d4aef2018-05-08 14:34:33 +01001196
1197Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1198{
1199 ARM_COMPUTE_UNUSED(op);
1200
1201 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
Georgios Pinitas8f5802f2019-02-22 11:08:32 +00001202 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
giuros01154bc1c2019-03-26 17:44:40 +00001203
1204 if(input->num_channels() == 1)
1205 {
Luca Foschianiee939fb2020-01-28 10:38:07 +00001206 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
giuros01154bc1c2019-03-26 17:44:40 +00001207 }
1208 else
1209 {
1210 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
1211 ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM);
1212 ARM_COMPUTE_RETURN_ERROR_ON(axis != 2);
1213 }
John Richardson73d4aef2018-05-08 14:34:33 +01001214
1215 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001216 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
John Richardson73d4aef2018-05-08 14:34:33 +01001217
1218 if(output->total_size() != 0)
1219 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001220 bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
1221 if(!is_arg_min_max)
1222 {
1223 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Isabella Gottardi0a1090a2019-02-14 18:07:36 +00001224 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
giuros01154bc1c2019-03-26 17:44:40 +00001225 ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels());
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001226 }
1227 else
1228 {
Michele Di Giorgio9637b2e2019-09-23 16:49:49 +01001229 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32, DataType::S32);
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001230 }
John Richardson73d4aef2018-05-08 14:34:33 +01001231
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001232 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001233 const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
1234 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
1235 }
1236
1237 return Status{};
1238}
1239
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001240std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op)
John Richardson73d4aef2018-05-08 14:34:33 +01001241{
1242 // Calculate output shape and set if empty
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001243 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001244
1245 // Output auto initialization if not yet initialized
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001246 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
Sang-Hoon Parkeaa01ab2019-11-11 17:33:28 +00001247 DataType output_data_type = is_arg_min_max ? DataType::S32 : input->data_type();
giuros01154bc1c2019-03-26 17:44:40 +00001248 auto_init_if_empty(*output, input->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
John Richardson73d4aef2018-05-08 14:34:33 +01001249
1250 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
1251
1252 // Configure kernel window
1253 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
1254 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
1255 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
1256
1257 bool window_changed = update_window_and_padding(win, input_access, output_access);
1258 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
1259
1260 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
1261
1262 return std::make_tuple(err, win);
1263}
Georgios Pinitasd9769582017-08-03 10:19:40 +01001264} // namespace
1265
1266NEReductionOperationKernel::NEReductionOperationKernel()
1267 : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE), _border_size()
1268{
1269}
1270
1271BorderSize NEReductionOperationKernel::border_size() const
1272{
1273 return _border_size;
1274}
1275
1276void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
1277{
1278 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001279
John Richardson73d4aef2018-05-08 14:34:33 +01001280 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001281
1282 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
1283
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001284 _input = input;
1285 _output = output;
1286 _border_size = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
1287 _op = op;
1288 _reduction_axis = axis;
Georgios Pinitasd9769582017-08-03 10:19:40 +01001289
1290 // Configure kernel window
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001291 auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis, op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001292
John Richardson73d4aef2018-05-08 14:34:33 +01001293 ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001294
John Richardson73d4aef2018-05-08 14:34:33 +01001295 INEKernel::configure(std::get<1>(win_config));
1296}
1297
1298Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1299{
1300 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001301 ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis, op)));
John Richardson73d4aef2018-05-08 14:34:33 +01001302
1303 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +01001304}
1305
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001306void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001307{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001308 ARM_COMPUTE_UNUSED(info);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001309 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1310 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1311
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001312 reduce_op(window, _input, _output, _reduction_axis, _op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001313}
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001314} // namespace arm_compute