blob: 5a52216eacf5bf2b521a40fc3101decb5fb46d07 [file] [log] [blame]
/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/core/NEON/kernels/NEReductionOperationKernel.h"
25
Georgios Pinitas8f5802f2019-02-22 11:08:32 +000026#include "arm_compute/core/CPP/Validate.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010027#include "arm_compute/core/Coordinates.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/IAccessWindow.h"
30#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/INEKernel.h"
32#include "arm_compute/core/NEON/NEMath.h"
John Richardson73d4aef2018-05-08 14:34:33 +010033#include "arm_compute/core/TensorInfo.h"
Luca Foschianiee939fb2020-01-28 10:38:07 +000034#include "arm_compute/core/Utils.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010035#include "arm_compute/core/Validate.h"
Michalis Spyrou19bd4122020-01-22 10:27:06 +000036#include "arm_compute/core/utils/misc/SaturateCast.h"
Michalis Spyrouaea14c62019-01-03 11:10:25 +000037#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010038
Michalis Spyroubcf8a962018-10-12 10:51:31 +010039#include "arm_compute/core/NEON/wrapper/wrapper.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010040#include <arm_neon.h>
41
Michalis Spyroubcf8a962018-10-12 10:51:31 +010042namespace arm_compute
43{
Georgios Pinitasd9769582017-08-03 10:19:40 +010044namespace
45{
Luca Foschianiee939fb2020-01-28 10:38:07 +000046// Helper function that calls vqmovun/vqmvn, vcombine and vstore, allows templating of RedOpYZW_quantized
47template <typename T>
48void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output)
49{
50 if(std::is_same<T, uint8_t>::value)
51 {
52 auto res = wrapper::vcombine(wrapper::vqmovun(t1), wrapper::vqmovun(t2));
53 wrapper::vstore(output.ptr(), res);
54 }
55 else
56 {
57 auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2));
58 wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr()), res);
59 }
60}
61
Michalis Spyroub9626ab2019-05-13 17:41:01 +010062template <typename T>
63uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000064{
65 uint32x4_t mask{ 0 };
66 if(op == ReductionOperation::ARG_IDX_MIN)
67 {
68 mask = wrapper::vcgt(b, a);
69 }
70 else
71 {
72 mask = wrapper::vclt(b, a);
73 }
74
75 uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
76 if(axis != 0)
77 {
78 vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
79 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +000080 uint32x4x4_t res = { { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000081
82 return res;
83}
84
Luca Foschianiee939fb2020-01-28 10:38:07 +000085template <typename T>
86uint32x4x4_t calculate_index_quantized(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000087{
Georgios Pinitasd57891a2019-02-19 18:10:03 +000088 uint32x4x4_t mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000089 uint8x16_t mask_u8{ 0 };
90 if(op == ReductionOperation::ARG_IDX_MIN)
91 {
92 mask_u8 = wrapper::vcgt(b, a);
93 }
94 else
95 {
96 mask_u8 = wrapper::vclt(b, a);
97 }
Michalis Spyrou254a48a2019-01-14 17:27:39 +000098 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
99 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
100 mask.val[0] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
101 mask.val[1] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
102 mask.val[2] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
103 mask.val[3] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
104
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000105 uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
106 { idx + 4, idx + 5, idx + 6, idx + 7 },
107 { idx + 8, idx + 9, idx + 10, idx + 11 },
108 { idx + 12, idx + 13, idx + 14, idx + 15 }
109 }
110 };
111 if(axis != 0)
112 {
113 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
114 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
115 vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
116 vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
117 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000118 uint32x4x4_t res =
119 {
120 {
121 vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
122 vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
123 vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
124 vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
125 }
126 };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000127
128 return res;
129}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100130
131// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000132template <typename T>
133inline typename std::enable_if < std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
134 typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type >::type
135 calculate_min(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100136{
137 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
138 return wrapper::vpmin(pmin, pmin);
139}
140
Luca Foschianiee939fb2020-01-28 10:38:07 +0000141// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
142template <typename T>
143inline typename std::enable_if < std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
144 typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type >::type
145 calculate_min(T in)
146{
147 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
148 pmin = wrapper::vpmin(pmin, pmin);
149 pmin = wrapper::vpmin(pmin, pmin);
150 return wrapper::vpmin(pmin, pmin);
151}
152
Usama Arifa4a08ad2019-05-20 12:38:33 +0100153// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000154template <typename T>
155inline typename std::enable_if < std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
156 typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type >::type
157 calculate_max(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100158{
159 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
160 return wrapper::vpmax(pmax, pmax);
161}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100162
163// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000164template <typename T>
165inline typename std::enable_if < std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
166 typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type >::type
167 calculate_max(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100168{
169 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
Luca Foschianiee939fb2020-01-28 10:38:07 +0000170 pmax = wrapper::vpmax(pmax, pmax);
171 pmax = wrapper::vpmax(pmax, pmax);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100172 return wrapper::vpmax(pmax, pmax);
173}
174
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100175template <typename T>
176uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000177{
178 uint32x4_t res_idx_mask{ 0 };
179 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
180
181 if(op == ReductionOperation::ARG_IDX_MIN)
182 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100183 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000184 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
185 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
186 }
187 else
188 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100189 auto pmax = calculate_max(vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100190 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000191 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
192 }
193
194 res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
195 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
196 pmin = wrapper::vpmin(pmin, pmin);
197 uint32_t res = wrapper::vgetlane(pmin, 0);
198
199 return (res - 0xFFFFFFFF);
200}
201
Luca Foschianiee939fb2020-01-28 10:38:07 +0000202template <typename T>
203uint32_t calculate_vector_index_quantized(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000204{
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000205 uint32x4x4_t res_idx_mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000206 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
207 uint8x16_t mask_u8{ 0 };
208 if(op == ReductionOperation::ARG_IDX_MIN)
209 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100210 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000211 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
212 }
213 else
214 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100215 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000216 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
217 }
218
219 // Widen vectors
220 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
221 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
222 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
223 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
224 auto wide_u32_3 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
225 auto wide_u32_4 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
226 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
227 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
228 res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
229 res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
230 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
231 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
232 res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
233 res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);
234
235 uint32_t res = 0xFFFFFFFF;
236 int iter = 0;
237 do
238 {
239 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
240 pmin = wrapper::vpmin(pmin, pmin);
241 res = std::min(wrapper::vgetlane(pmin, 0), res);
242 iter++;
243 }
244 while(iter < 4);
245
246 return (res - 0xFFFFFFFF);
247}
Luca Foschianiee939fb2020-01-28 10:38:07 +0000248
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000249#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitasfad18382019-06-05 15:12:22 +0100250template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000251uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
252{
253 uint32x4x2_t mask{ 0 };
254 uint16x8_t mask_u16{ 0 };
255 if(op == ReductionOperation::ARG_IDX_MIN)
256 {
257 mask_u16 = wrapper::vcgt(b, a);
258 }
259 else
260 {
261 mask_u16 = wrapper::vclt(b, a);
262 }
263 mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
264 mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
265 uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
266 { idx + 4, idx + 5, idx + 6, idx + 7 }
267 }
268 };
269 if(axis != 0)
270 {
271 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
272 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
273 }
274 uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
275 wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
276 0, 0
277 };
278
279 return res;
280}
281
Usama Arifa4a08ad2019-05-20 12:38:33 +0100282// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
283inline float16x4_t calculate_min(float16x8_t in)
284{
285 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
286 pmin = wrapper::vpmin(pmin, pmin);
287 return wrapper::vpmin(pmin, pmin);
288}
289// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
290inline float16x4_t calculate_max(float16x8_t in)
291{
292 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
293 pmax = wrapper::vpmax(pmax, pmax);
294 return wrapper::vpmax(pmax, pmax);
295}
296
Usama Arif0a5a57a2019-05-23 14:20:33 +0100297template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000298uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
299{
300 uint32x4x2_t res_idx_mask{ 0 };
301 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
302 uint16x8_t mask_u16;
303 if(op == ReductionOperation::ARG_IDX_MIN)
304 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100305 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000306 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
307 }
308 else
309 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100310 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000311 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
312 }
313
314 // Widen vectors
315 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
316 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
317 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
318 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
319 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
320 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
321
322 uint32_t res = 0xFFFFFFFF;
323 int iter = 0;
324 do
325 {
326 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
327 pmin = wrapper::vpmin(pmin, pmin);
328 res = std::min(wrapper::vgetlane(pmin, 0), res);
329 iter++;
330 }
331 while(iter < 2);
332
333 return (res - 0xFFFFFFFF);
334}
335#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
336
Georgios Pinitasd9769582017-08-03 10:19:40 +0100337template <class F>
338class Reducer
339{
340public:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000341 static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100342 {
343 // Set out window
344 Window out_window(window);
345 out_window.set(Window::DimX, Window::Dimension(0, 0, 0));
346
347 // Get first input and output slices
348 Window in_slice = window.first_slice_window_1D();
349 Window out_slice = out_window.first_slice_window_1D();
350
351 do
352 {
353 Iterator in(input, in_slice);
354 Iterator out(output, out_slice);
355
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000356 f(in, out, in_slice, out_slice, *input->info(), op);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100357 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100358 while(window.slide_window_slice_1D(in_slice) && out_window.slide_window_slice_1D(out_slice));
359 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000360 static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100361 {
362 // Set in window
363 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000364 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100365
366 in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000367 out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100368
369 // Get first input and output slices
370 Window in_slice = in_window.first_slice_window_2D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000371 Window out_slice = out_window.first_slice_window_2D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100372
373 do
374 {
375 Iterator in(input, in_slice);
376 Iterator out(output, out_slice);
377
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000378 f(in, out, in_slice, out_slice, *input->info(), 1, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100379 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000380 while(in_window.slide_window_slice_2D(in_slice) && out_window.slide_window_slice_2D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100381 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000382 static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100383 {
384 // Set in window
385 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000386 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100387
388 in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000389 out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100390
391 // Get first input and output slices
392 Window in_slice = in_window.first_slice_window_3D();
Michalis Spyrou2897e612018-11-20 18:38:29 +0000393 Window out_slice = out_window.first_slice_window_3D();
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100394
395 do
396 {
397 Iterator in(input, in_slice);
398 Iterator out(output, out_slice);
399
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000400 f(in, out, in_slice, out_slice, *input->info(), 2, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100401 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000402 while(in_window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_3D(out_slice));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100403 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000404 static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100405 {
406 // Set in/out window
407 Window in_window(window);
408 Window out_window(window);
409
410 in_window.set(3, Window::Dimension(0, 1, 1));
411 out_window.set(3, Window::Dimension(0, 1, 1));
412
413 // Get first input and output slices
414 Window in_slice = in_window.first_slice_window_4D();
415 Window out_slice = out_window.first_slice_window_4D();
416
417 do
418 {
419 Iterator in(input, in_slice);
420 Iterator out(output, out_slice);
421
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000422 f(in, out, in_slice, out_slice, *input->info(), 3, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100423 }
424 while(in_window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_4D(out_slice));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100425 }
426};
427
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000428template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100429struct RedOpX
Georgios Pinitasd9769582017-08-03 10:19:40 +0100430{
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100431 /** NEON vector tag type. */
432 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
433
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000434 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100435 {
436 ARM_COMPUTE_UNUSED(out_slice);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000437 auto init_res_value = static_cast<T>(0.f);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100438 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000439 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100440 case ReductionOperation::ARG_IDX_MAX:
441 case ReductionOperation::ARG_IDX_MIN:
442 case ReductionOperation::MIN:
Usama Arif28f0dd92019-05-20 13:44:34 +0100443 case ReductionOperation::MAX:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100444 {
445 init_res_value = *reinterpret_cast<T *>(input.ptr());
446 break;
447 }
448 case ReductionOperation::PROD:
449 {
450 init_res_value = static_cast<T>(1.f);
451 break;
452 }
453 default:
454 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000455 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000456 auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000457 uint32x4x4_t vec_res_idx{ { 0 } };
Georgios Pinitasd9769582017-08-03 10:19:40 +0100458
459 execute_window_loop(in_slice, [&](const Coordinates & id)
460 {
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100461 const auto in_ptr = reinterpret_cast<const T *>(input.ptr());
462 const auto vec_elements = wrapper::vloadq(in_ptr);
463
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000464 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100465 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000466 case ReductionOperation::SUM_SQUARE:
467 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
468 break;
469 case ReductionOperation::MEAN_SUM:
470 case ReductionOperation::SUM:
471 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
472 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000473 case ReductionOperation::PROD:
474 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
475 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000476 case ReductionOperation::ARG_IDX_MIN:
477 {
478 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100479 vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000480 vec_res_value = temp_vec_res_value;
481 break;
482 }
483 case ReductionOperation::ARG_IDX_MAX:
484 {
485 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100486 vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000487 vec_res_value = temp_vec_res_value;
488 break;
489 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100490 case ReductionOperation::MIN:
491 {
492 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
493 break;
494 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100495 case ReductionOperation::MAX:
496 {
497 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
498 break;
499 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000500 default:
501 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100502 }
Georgios Pinitasd9769582017-08-03 10:19:40 +0100503 },
504 input);
505
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000506 switch(op)
Michele Di Giorgio1c948d42018-11-20 16:03:01 +0000507 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000508 case ReductionOperation::SUM:
509 case ReductionOperation::SUM_SQUARE:
510 case ReductionOperation::MEAN_SUM:
511 {
512 auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
513 for(int i = 0; i < S / 4; ++i)
514 {
515 carry_res = wrapper::vpadd(carry_res, carry_res);
516 }
517 auto res = wrapper::vgetlane(carry_res, 0);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100518
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000519 if(op == ReductionOperation::MEAN_SUM)
520 {
521 res /= in_info.dimension(0);
522 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100523
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000524 *(reinterpret_cast<T *>(output.ptr())) = res;
525 break;
526 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000527 case ReductionOperation::PROD:
528 {
529 auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
530 T res = 1;
531 for(int i = 0; i < S / 2; ++i)
532 {
533 res *= wrapper::vgetlane(carry_res, i);
534 }
535 *(reinterpret_cast<T *>(output.ptr())) = res;
536 break;
537 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000538 case ReductionOperation::ARG_IDX_MIN:
539 case ReductionOperation::ARG_IDX_MAX:
540 {
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100541 auto res = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000542 *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
543 break;
544 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100545 case ReductionOperation::MIN:
546 {
547 *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0);
548 break;
549 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100550 case ReductionOperation::MAX:
551 {
552 *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_max(vec_res_value), 0);
553 break;
554 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000555 default:
556 ARM_COMPUTE_ERROR("Not supported");
557 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100558 }
559};
560
Luca Foschianiee939fb2020-01-28 10:38:07 +0000561template <typename T>
562struct RedOpX_quantized
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100563{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000564 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100565 {
566 ARM_COMPUTE_UNUSED(out_slice);
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100567
Luca Foschianiee939fb2020-01-28 10:38:07 +0000568 using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
569
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100570 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
571
Luca Foschianiee939fb2020-01-28 10:38:07 +0000572 auto vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
573 auto vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
574 auto vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
575 auto vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100576
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000577 auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
578 auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
579 auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
580 auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));
581
Luca Foschianiee939fb2020-01-28 10:38:07 +0000582 typename wrapper::traits::neon_vector<T, 16>::type vec_res_value = { 0 };
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000583
Usama Arif28f0dd92019-05-20 13:44:34 +0100584 if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000585 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000586 vec_res_value = wrapper::vdup_n(*reinterpret_cast<T *>(input.ptr()), wrapper::traits::vector_128_tag{});
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000587 }
588
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000589 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100590 execute_window_loop(in_slice, [&](const Coordinates & id)
591 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000592 const auto vec_elements = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000593 switch(op)
594 {
595 case ReductionOperation::SUM:
596 case ReductionOperation::MEAN_SUM:
597 {
598 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
599 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100600
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000601 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
602 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
603 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
604 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100605
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000606 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
607 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
608 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
609 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
610 break;
611 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000612 case ReductionOperation::PROD:
613 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100614 const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
615 const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000616
Luca Foschianiee939fb2020-01-28 10:38:07 +0000617 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
618 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000619
Luca Foschianiee939fb2020-01-28 10:38:07 +0000620 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
621 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
622 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
623 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000624
Luca Foschianiee939fb2020-01-28 10:38:07 +0000625 auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
626 auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
627 auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
628 auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000629
630 //de-quantize vec_elements
631 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
632 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
633 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
634 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
635
636 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
637 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
638 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
639 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
640 break;
641 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000642 case ReductionOperation::ARG_IDX_MIN:
643 {
644 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Luca Foschianiee939fb2020-01-28 10:38:07 +0000645 vec_res_idx = calculate_index_quantized(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000646 vec_res_value = temp_vec_res_value;
647 break;
648 }
649 case ReductionOperation::ARG_IDX_MAX:
650 {
651 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Luca Foschianiee939fb2020-01-28 10:38:07 +0000652 vec_res_idx = calculate_index_quantized(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000653 vec_res_value = temp_vec_res_value;
654 break;
655 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100656 case ReductionOperation::MIN:
657 {
658 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
659 break;
660 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100661 case ReductionOperation::MAX:
662 {
663 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
664 break;
665 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000666 default:
667 ARM_COMPUTE_ERROR("Not supported");
668 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100669 },
670 input);
671
Usama Arifa4a08ad2019-05-20 12:38:33 +0100672 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100673 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100674 case ReductionOperation::ARG_IDX_MIN:
675 case ReductionOperation::ARG_IDX_MAX:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000676 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000677 auto res = calculate_vector_index_quantized(vec_res_idx, vec_res_value, op);
678 *(reinterpret_cast<PromotedType *>(output.ptr())) = res;
Usama Arifa4a08ad2019-05-20 12:38:33 +0100679 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000680 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100681 case ReductionOperation::MIN:
682 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000683 *(output.ptr()) = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
Usama Arifa4a08ad2019-05-20 12:38:33 +0100684 break;
685 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100686 case ReductionOperation::MAX:
687 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000688 *(output.ptr()) = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
Usama Arif28f0dd92019-05-20 13:44:34 +0100689 break;
690 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100691 case ReductionOperation::PROD:
692 {
693 auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
694 carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
695 carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000696
Usama Arifa4a08ad2019-05-20 12:38:33 +0100697 float res = wrapper::vgetlane(carry_res, 0);
698 res *= wrapper::vgetlane(carry_res, 1);
699 res *= wrapper::vgetlane(carry_res, 2);
700 res *= wrapper::vgetlane(carry_res, 3);
701
702 //re-quantize result
Luca Foschianiee939fb2020-01-28 10:38:07 +0000703 if(std::is_same<T, uint8_t>::value)
704 {
705 res = quantize_qasymm8(res, iq_info);
706 }
707 else
708 {
709 res = quantize_qasymm8_signed(res, iq_info);
710 }
711
712 *reinterpret_cast<T *>(output.ptr()) = static_cast<T>(res);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100713 break;
714 }
715 default:
716 {
717 auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
718 carry_res = wrapper::vadd(carry_res, vec_res_value3);
719 carry_res = wrapper::vadd(carry_res, vec_res_value4);
720
721 auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
722 carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000723 auto res = static_cast<int32_t>(wrapper::vgetlane(carry_paddition, 0));
Usama Arifa4a08ad2019-05-20 12:38:33 +0100724
725 if(op == ReductionOperation::MEAN_SUM)
726 {
Manuel Bottini77b88592020-05-04 18:42:32 +0100727 res /= static_cast<int32_t>(in_info.dimension(0));
Usama Arifa4a08ad2019-05-20 12:38:33 +0100728 }
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000729 else
730 {
731 // Subtract accumulated offsets
732 res -= (in_info.dimension(0) - 1) * iq_info.offset;
733 }
Luca Foschianiee939fb2020-01-28 10:38:07 +0000734 *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(res);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100735 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000736 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100737 }
738};
739
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000740template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100741struct RedOpYZW
742{
743 /** NEON vector tag type. */
744 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000745 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100746
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000747 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100748 {
749 ARM_COMPUTE_UNUSED(out_slice);
750
giuros01154bc1c2019-03-26 17:44:40 +0000751 execute_window_loop(in_slice, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100752 {
Michalis Spyrou728d6f72019-01-16 13:57:58 +0000753 neon_vector vec_res_value = { 0 };
Usama Arifa4a08ad2019-05-20 12:38:33 +0100754 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000755 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100756 case ReductionOperation::ARG_IDX_MAX:
757 case ReductionOperation::ARG_IDX_MIN:
758 case ReductionOperation::MIN:
Usama Arif28f0dd92019-05-20 13:44:34 +0100759 case ReductionOperation::MAX:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100760 {
761 vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
762 break;
763 }
764 case ReductionOperation::PROD:
765 {
766 vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
767 break;
768 }
769 default:
770 {
771 vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
772 break;
773 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000774 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000775 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000776
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100777 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
778 {
Pablo Tello07958782020-01-09 14:43:36 +0000779 const T *in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.strides_in_bytes()[axis] * dim);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100780 const auto vec_elements = wrapper::vloadq(in_ptr);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000781 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100782 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000783 case ReductionOperation::SUM:
784 case ReductionOperation::MEAN_SUM:
785 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
786 break;
787 case ReductionOperation::SUM_SQUARE:
788 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
789 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000790 case ReductionOperation::PROD:
791 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
792 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000793 case ReductionOperation::ARG_IDX_MIN:
794 {
795 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
796 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
797 vec_res_value = temp_vec_res_value;
798 break;
799 }
800 case ReductionOperation::ARG_IDX_MAX:
801 {
802 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
803 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
804 vec_res_value = temp_vec_res_value;
805 break;
806 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100807 case ReductionOperation::MIN:
808 {
809 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
810 break;
811 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100812 case ReductionOperation::MAX:
813 {
814 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
815 break;
816 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000817 default:
818 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100819 }
820 }
821
822 if(op == ReductionOperation::MEAN_SUM)
823 {
824 auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000825 vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100826 }
827
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000828 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
829 {
830 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
Michele Di Giorgio81d7e782019-08-16 18:03:35 +0100831#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
832 if(std::is_same<T, float16_t>::value)
833 {
834 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vec_res_idx.val[1]);
835 }
836#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000837 }
838 else
839 {
840 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value);
841 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100842 },
843 input, output);
844 }
845};
846
giuros01154bc1c2019-03-26 17:44:40 +0000847template <typename T, int S, int axis, ReductionOperation op>
848struct RedOpYZW_complex
849{
850 /** NEON vector tag type. */
851 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
852 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
853
854 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int, const ReductionOperation)
855 {
856 ARM_COMPUTE_UNUSED(out_slice);
857 ARM_COMPUTE_ERROR_ON(axis != 2);
858
859 const size_t stride_z = in_info.strides_in_bytes()[axis];
860
861 execute_window_loop(in_slice, [&](const Coordinates &)
862 {
863 neon_vector vec_res_value_0 = { 0 };
864 neon_vector vec_res_value_1 = { 0 };
865
866 vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
867 vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
868
869 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
870 {
871 T *in_ptr_0;
872 T *in_ptr_1;
873 switch(axis)
874 {
875 case 2:
876 in_ptr_0 = reinterpret_cast<T *>(input.ptr() + stride_z * dim);
877 in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 16 + stride_z * dim);
878 break;
879 default:
880 ARM_COMPUTE_ERROR("Not supported");
881 }
882 const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
883 const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);
884
885 switch(op)
886 {
887 case ReductionOperation::SUM:
888 vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
889 vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
890 break;
891 default:
892 ARM_COMPUTE_ERROR("Not supported");
893 }
894 }
895
896 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value_0);
897 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + 16), vec_res_value_1);
898
899 },
900 input, output);
901 }
902};
903
Luca Foschianiee939fb2020-01-28 10:38:07 +0000904template <typename T>
905struct RedOpYZW_quantized
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100906{
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000907 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100908 {
909 ARM_COMPUTE_UNUSED(out_slice);
910
Luca Foschianiee939fb2020-01-28 10:38:07 +0000911 using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
912
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100913 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
914
giuros01154bc1c2019-03-26 17:44:40 +0000915 execute_window_loop(in_slice, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100916 {
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000917 uint32x4x4_t vec_res_idx{ { 0 } };
Luca Foschianiee939fb2020-01-28 10:38:07 +0000918 auto vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
919 auto vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
920 auto vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
921 auto vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000922
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000923 auto vec_res_value1_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
924 auto vec_res_value2_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
925 auto vec_res_value3_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
926 auto vec_res_value4_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000927
Luca Foschianiee939fb2020-01-28 10:38:07 +0000928 auto vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000929
930 for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100931 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000932 const T *in_ptr = reinterpret_cast<T *>(input.ptr()) + in_info.strides_in_bytes()[axis] * index_dim;
933 const auto vec_elements = wrapper::vloadq(in_ptr);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000934 switch(op)
935 {
936 case ReductionOperation::SUM:
937 case ReductionOperation::MEAN_SUM:
938 {
939 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
940 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100941
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000942 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
943 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
944 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
945 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100946
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000947 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
948 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
949 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
950 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
951 break;
952 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000953 case ReductionOperation::PROD:
954 {
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000955 const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
956 const auto scale32x4f_4 = wrapper::vdup_n(iq_info.scale, wrapper::traits::vector_128_tag{});
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000957
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000958 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
959 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000960
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000961 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
962 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
963 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
964 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000965
Luca Foschianiee939fb2020-01-28 10:38:07 +0000966 auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
967 auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
968 auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
969 auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000970
971 //de-quantize vec_elements
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000972 temp32x4f_1 = wrapper::vmul(wrapper::vsub(temp32x4f_1, offset32x4f_4), scale32x4f_4);
973 temp32x4f_2 = wrapper::vmul(wrapper::vsub(temp32x4f_2, offset32x4f_4), scale32x4f_4);
974 temp32x4f_3 = wrapper::vmul(wrapper::vsub(temp32x4f_3, offset32x4f_4), scale32x4f_4);
975 temp32x4f_4 = wrapper::vmul(wrapper::vsub(temp32x4f_4, offset32x4f_4), scale32x4f_4);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000976
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000977 vec_res_value1_f = wrapper::vmul(temp32x4f_1, vec_res_value1_f);
978 vec_res_value2_f = wrapper::vmul(temp32x4f_2, vec_res_value2_f);
979 vec_res_value3_f = wrapper::vmul(temp32x4f_3, vec_res_value3_f);
980 vec_res_value4_f = wrapper::vmul(temp32x4f_4, vec_res_value4_f);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000981 break;
982 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000983 case ReductionOperation::ARG_IDX_MIN:
984 {
985 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Luca Foschianiee939fb2020-01-28 10:38:07 +0000986 vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000987 vec_res_value = temp_vec_res_value;
988 break;
989 }
990 case ReductionOperation::ARG_IDX_MAX:
991 {
992 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Luca Foschianiee939fb2020-01-28 10:38:07 +0000993 vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000994 vec_res_value = temp_vec_res_value;
995 break;
996 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100997 case ReductionOperation::MIN:
998 {
999 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1000 break;
1001 }
Usama Arif28f0dd92019-05-20 13:44:34 +01001002 case ReductionOperation::MAX:
1003 {
1004 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1005 break;
1006 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001007 default:
1008 ARM_COMPUTE_ERROR("Not supported");
1009 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001010 }
1011
1012 if(op == ReductionOperation::MEAN_SUM)
1013 {
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001014 const auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<float>(in_info.dimension(axis)), wrapper::traits::vector_128_tag{}));
Luca Foschianiee939fb2020-01-28 10:38:07 +00001015 vec_res_value1_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value1), vec_width_inv);
1016 vec_res_value2_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value2), vec_width_inv);
1017 vec_res_value3_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value3), vec_width_inv);
1018 vec_res_value4_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value4), vec_width_inv);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001019
Luca Foschianiee939fb2020-01-28 10:38:07 +00001020 vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f);
1021 vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f);
1022 vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f);
1023 vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f);
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001024 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001025 else if(op == ReductionOperation::PROD)
1026 {
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001027 const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01001028 const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale));
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001029
1030 //re-quantize
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001031 vec_res_value1_f = wrapper::vadd(wrapper::vmul(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
1032 vec_res_value2_f = wrapper::vadd(wrapper::vmul(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
1033 vec_res_value3_f = wrapper::vadd(wrapper::vmul(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
1034 vec_res_value4_f = wrapper::vadd(wrapper::vmul(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001035
Luca Foschianiee939fb2020-01-28 10:38:07 +00001036 vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f);
1037 vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f);
1038 vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f);
1039 vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f);
Manuel Bottini1d4f3852019-01-14 15:14:43 +00001040 }
1041
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001042 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
1043 {
1044 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
1045 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vec_res_idx.val[1]);
1046 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vec_res_idx.val[2]);
1047 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vec_res_idx.val[3]);
1048 }
Nate Craun14252f92020-02-26 18:36:17 -05001049 else if(op == ReductionOperation::MIN || op == ReductionOperation::MAX)
Usama Arifa4a08ad2019-05-20 12:38:33 +01001050 {
Luca Foschianiee939fb2020-01-28 10:38:07 +00001051 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value);
Usama Arifa4a08ad2019-05-20 12:38:33 +01001052 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001053 else
1054 {
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001055 if(op == ReductionOperation::SUM)
1056 {
1057 // Subtract offsets
1058 auto offsets = vdupq_n_s32((in_info.dimension(axis) - 1) * iq_info.offset);
1059
Luca Foschianiee939fb2020-01-28 10:38:07 +00001060 auto vec_res_s_value1 = wrapper::vreinterpret(vec_res_value1);
1061 auto vec_res_s_value2 = wrapper::vreinterpret(vec_res_value2);
1062 auto vec_res_s_value3 = wrapper::vreinterpret(vec_res_value3);
1063 auto vec_res_s_value4 = wrapper::vreinterpret(vec_res_value4);
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001064
1065 vec_res_s_value1 = wrapper::vsub(vec_res_s_value1, offsets);
1066 vec_res_s_value2 = wrapper::vsub(vec_res_s_value2, offsets);
1067 vec_res_s_value3 = wrapper::vsub(vec_res_s_value3, offsets);
1068 vec_res_s_value4 = wrapper::vsub(vec_res_s_value4, offsets);
1069
1070 const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_s_value1), wrapper::vqmovn(vec_res_s_value2));
1071 const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_s_value3), wrapper::vqmovn(vec_res_s_value4));
Luca Foschianiee939fb2020-01-28 10:38:07 +00001072
1073 combine_and_store<T>(temp16x8t_1, temp16x8t_2, output);
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001074 }
1075 else
1076 {
1077 const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1078 const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1079 auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
Luca Foschianiee939fb2020-01-28 10:38:07 +00001080
1081 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), res);
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001082 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001083 }
1084
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001085 },
1086 input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001087 }
1088};
1089
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001090void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001091{
giuros01154bc1c2019-03-26 17:44:40 +00001092 const bool is_complex = (input->info()->num_channels() == 2);
1093
1094 if(is_complex)
1095 {
1096 switch(axis)
1097 {
1098 case 2:
1099 switch(input->info()->data_type())
1100 {
1101 case DataType::F32:
1102 switch(op)
1103 {
1104 case ReductionOperation::SUM:
1105 return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op);
1106 default:
1107 ARM_COMPUTE_ERROR("Not supported");
1108 }
1109 default:
1110 ARM_COMPUTE_ERROR("Not supported");
1111 }
1112 default:
1113 ARM_COMPUTE_ERROR("Not supported");
1114 }
1115 }
1116
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001117 switch(axis)
1118 {
1119 case 0:
1120 switch(input->info()->data_type())
1121 {
1122 case DataType::QASYMM8:
Luca Foschianiee939fb2020-01-28 10:38:07 +00001123 return Reducer<RedOpX_quantized<uint8_t>>::reduceX(window, input, output, RedOpX_quantized<uint8_t>(), op);
1124 case DataType::QASYMM8_SIGNED:
1125 return Reducer<RedOpX_quantized<int8_t>>::reduceX(window, input, output, RedOpX_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001126#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1127 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001128 return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001129#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1130 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001131 return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001132 case DataType::S32:
1133 return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001134 default:
1135 ARM_COMPUTE_ERROR("Not supported");
1136 }
1137 case 1:
1138 switch(input->info()->data_type())
1139 {
1140 case DataType::QASYMM8:
Luca Foschianiee939fb2020-01-28 10:38:07 +00001141 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceY(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
1142 case DataType::QASYMM8_SIGNED:
1143 return Reducer<RedOpYZW_quantized<int8_t>>::reduceY(window, input, output, RedOpYZW_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001144#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1145 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001146 return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001147#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1148 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001149 return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001150 case DataType::S32:
1151 return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001152 default:
1153 ARM_COMPUTE_ERROR("Not supported");
1154 }
1155 case 2:
1156 switch(input->info()->data_type())
1157 {
1158 case DataType::QASYMM8:
Luca Foschianiee939fb2020-01-28 10:38:07 +00001159 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
1160 case DataType::QASYMM8_SIGNED:
1161 return Reducer<RedOpYZW_quantized<int8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001162#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1163 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001164 return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001165#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1166 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001167 return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001168 case DataType::S32:
1169 return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001170 default:
1171 ARM_COMPUTE_ERROR("Not supported");
1172 }
1173 case 3:
1174 switch(input->info()->data_type())
1175 {
1176 case DataType::QASYMM8:
Luca Foschianiee939fb2020-01-28 10:38:07 +00001177 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceW(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
1178 case DataType::QASYMM8_SIGNED:
1179 return Reducer<RedOpYZW_quantized<int8_t>>::reduceW(window, input, output, RedOpYZW_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001180#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1181 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001182 return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001183#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1184 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001185 return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001186 case DataType::S32:
1187 return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001188 default:
1189 ARM_COMPUTE_ERROR("Not supported");
1190 }
1191 default:
1192 ARM_COMPUTE_ERROR("Unsupported reduction axis");
1193 }
1194}
John Richardson73d4aef2018-05-08 14:34:33 +01001195
1196Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1197{
1198 ARM_COMPUTE_UNUSED(op);
1199
1200 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
Georgios Pinitas8f5802f2019-02-22 11:08:32 +00001201 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
giuros01154bc1c2019-03-26 17:44:40 +00001202
1203 if(input->num_channels() == 1)
1204 {
Luca Foschianiee939fb2020-01-28 10:38:07 +00001205 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
giuros01154bc1c2019-03-26 17:44:40 +00001206 }
1207 else
1208 {
1209 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
1210 ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM);
1211 ARM_COMPUTE_RETURN_ERROR_ON(axis != 2);
1212 }
John Richardson73d4aef2018-05-08 14:34:33 +01001213
1214 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001215 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
John Richardson73d4aef2018-05-08 14:34:33 +01001216
1217 if(output->total_size() != 0)
1218 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001219 bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
1220 if(!is_arg_min_max)
1221 {
1222 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Isabella Gottardi0a1090a2019-02-14 18:07:36 +00001223 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
giuros01154bc1c2019-03-26 17:44:40 +00001224 ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels());
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001225 }
1226 else
1227 {
Michele Di Giorgio9637b2e2019-09-23 16:49:49 +01001228 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32, DataType::S32);
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001229 }
John Richardson73d4aef2018-05-08 14:34:33 +01001230
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001231 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001232 const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
1233 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
1234 }
1235
1236 return Status{};
1237}
1238
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001239std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op)
John Richardson73d4aef2018-05-08 14:34:33 +01001240{
1241 // Calculate output shape and set if empty
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001242 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001243
1244 // Output auto initialization if not yet initialized
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001245 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
Sang-Hoon Parkeaa01ab2019-11-11 17:33:28 +00001246 DataType output_data_type = is_arg_min_max ? DataType::S32 : input->data_type();
giuros01154bc1c2019-03-26 17:44:40 +00001247 auto_init_if_empty(*output, input->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
John Richardson73d4aef2018-05-08 14:34:33 +01001248
1249 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
1250
1251 // Configure kernel window
1252 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
1253 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
1254 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
1255
1256 bool window_changed = update_window_and_padding(win, input_access, output_access);
1257 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
1258
1259 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
1260
1261 return std::make_tuple(err, win);
1262}
Georgios Pinitasd9769582017-08-03 10:19:40 +01001263} // namespace
1264
1265NEReductionOperationKernel::NEReductionOperationKernel()
1266 : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE), _border_size()
1267{
1268}
1269
1270BorderSize NEReductionOperationKernel::border_size() const
1271{
1272 return _border_size;
1273}
1274
1275void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
1276{
1277 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001278
John Richardson73d4aef2018-05-08 14:34:33 +01001279 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001280
1281 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
1282
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001283 _input = input;
1284 _output = output;
1285 _border_size = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
1286 _op = op;
1287 _reduction_axis = axis;
Georgios Pinitasd9769582017-08-03 10:19:40 +01001288
1289 // Configure kernel window
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001290 auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis, op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001291
John Richardson73d4aef2018-05-08 14:34:33 +01001292 ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001293
John Richardson73d4aef2018-05-08 14:34:33 +01001294 INEKernel::configure(std::get<1>(win_config));
1295}
1296
1297Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1298{
1299 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001300 ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis, op)));
John Richardson73d4aef2018-05-08 14:34:33 +01001301
1302 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +01001303}
1304
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001305void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001306{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001307 ARM_COMPUTE_UNUSED(info);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001308 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1309 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1310
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001311 reduce_op(window, _input, _output, _reduction_axis, _op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001312}
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001313} // namespace arm_compute