blob: 2bbd9452f229945d003fe858cf500d712427ebef [file] [log] [blame]
Georgios Pinitasd9769582017-08-03 10:19:40 +01001/*
Sheri Zhangac6499a2021-02-10 15:32:38 +00002 * Copyright (c) 2017-2021 Arm Limited.
Georgios Pinitasd9769582017-08-03 10:19:40 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrouebcebf12020-10-21 00:04:14 +010024#include "src/core/NEON/kernels/NEReductionOperationKernel.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010025
26#include "arm_compute/core/Coordinates.h"
27#include "arm_compute/core/Helpers.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010028#include "arm_compute/core/ITensor.h"
John Richardson73d4aef2018-05-08 14:34:33 +010029#include "arm_compute/core/TensorInfo.h"
Luca Foschianiee939fb2020-01-28 10:38:07 +000030#include "arm_compute/core/Utils.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010031#include "arm_compute/core/Validate.h"
Michalis Spyrouaea14c62019-01-03 11:10:25 +000032#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010033#include "src/core/CPP/Validate.h"
Michalis Spyrouebcebf12020-10-21 00:04:14 +010034#include "src/core/NEON/INEKernel.h"
Georgios Pinitasddb93bb2020-10-02 16:38:59 +010035#include "src/core/NEON/NEMath.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010036#include "src/core/helpers/AutoConfiguration.h"
37#include "src/core/helpers/WindowHelpers.h"
38#include "support/SaturateCast.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010039
Georgios Pinitasddb93bb2020-10-02 16:38:59 +010040#include "src/core/NEON/wrapper/wrapper.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010041#include <arm_neon.h>
42
Michalis Spyroubcf8a962018-10-12 10:51:31 +010043namespace arm_compute
44{
Georgios Pinitasd9769582017-08-03 10:19:40 +010045namespace
46{
// Narrows two saturated 16-bit accumulators back to 8-bit lanes (unsigned or
// signed depending on T), combines them into one 16-lane vector and stores the
// result at the output iterator position (plus an optional byte offset).
// Allows RedOpYZW_quantized to be templated over uint8_t/int8_t.
template <typename T>
void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0)
{
    if(!std::is_same<T, uint8_t>::value)
    {
        // Signed path: saturate to int8 and store through a signed pointer.
        const auto narrowed = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2));
        wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), narrowed);
        return;
    }
    // Unsigned path: saturate to uint8.
    const auto narrowed = wrapper::vcombine(wrapper::vqmovun(t1), wrapper::vqmovun(t2));
    wrapper::vstore(output.ptr() + offset, narrowed);
}
62
// Updates the tracked argmin/argmax indices for four 32-bit lanes.
// a: candidate winners, b: previous winners, c: previously tracked indices.
// Lanes where the candidate beats the previous winner take index idx+lane
// (or a broadcast idx when reducing along an axis other than 0).
template <typename T>
uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    // Per-lane mask: all-ones where `a` wins against `b` for the requested op.
    const uint32x4_t mask = (op == ReductionOperation::ARG_IDX_MIN) ? wrapper::vcgt(b, a) : wrapper::vclt(b, a);

    uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
    if(axis != 0)
    {
        // Along non-X axes every lane belongs to the same reduced position.
        vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }

    // Only the first vector of indices is used for 32-bit lane types.
    uint32x4x4_t selected = { { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 } };
    return selected;
}
85
// Quantized (8-bit lane) variant of calculate_index: compares 16 lanes at once,
// so candidate indices are tracked in four uint32x4_t vectors.
// a: candidate winners, b: previous winners, c: previously tracked indices.
template <typename T>
uint32x4x4_t calculate_index_quantized(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4x4_t mask{ { 0 } };
    uint8x16_t   mask_u8{ 0 };
    // Per-lane 8-bit mask (0xFF where `a` wins against `b` for the requested op).
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask_u8 = wrapper::vcgt(b, a);
    }
    else
    {
        mask_u8 = wrapper::vclt(b, a);
    }
    // Widen the 8-bit mask to four 32-bit masks: (x << lane_width) | x replicates
    // a 0xFF / 0x00 lane into an all-ones / all-zeros wider lane at each step.
    auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
    auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
    mask.val[0]     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
    mask.val[1]     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
    mask.val[2]     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
    mask.val[3]     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));

    // Candidate indices: consecutive along X, broadcast idx for any other axis.
    uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
                               { idx + 4, idx + 5, idx + 6, idx + 7 },
                               { idx + 8, idx + 9, idx + 10, idx + 11 },
                               { idx + 12, idx + 13, idx + 14, idx + 15 }
                             }
    };
    if(axis != 0)
    {
        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    // Bit-select: winning lanes take the candidate index, losing lanes keep c.
    uint32x4x4_t res =
    {
        {
            vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
            vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
            vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
            vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
        }
    };

    return res;
}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100131
// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
// 32-bit lane variant (float32x4_t / int32x4_t): two pairwise folds reduce four lanes.
template <typename T>
inline typename std::enable_if < std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
       typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type >::type
       calculate_min(T in)
{
    const auto half_min = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
    return wrapper::vpmin(half_min, half_min);
}
141
// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
// 8-bit lane variant (uint8x16_t / int8x16_t): sixteen lanes need four pairwise folds.
template <typename T>
inline typename std::enable_if < std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
       typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type >::type
       calculate_min(T in)
{
    auto folded = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
    for(int i = 0; i < 2; ++i)
    {
        folded = wrapper::vpmin(folded, folded);
    }
    return wrapper::vpmin(folded, folded);
}
153
// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
// 32-bit lane variant (float32x4_t / int32x4_t): two pairwise folds reduce four lanes.
template <typename T>
inline typename std::enable_if < std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
       typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type >::type
       calculate_max(T in)
{
    const auto half_max = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
    return wrapper::vpmax(half_max, half_max);
}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100163
// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
// 8-bit lane variant (uint8x16_t / int8x16_t): sixteen lanes need four pairwise folds.
template <typename T>
inline typename std::enable_if < std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
       typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type >::type
       calculate_max(T in)
{
    auto folded = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
    for(int i = 0; i < 2; ++i)
    {
        folded = wrapper::vpmax(folded, folded);
    }
    return wrapper::vpmax(folded, folded);
}
175
// Reduces the per-lane winning values and their tracked indices to the single
// scalar index of the overall min/max. The lowest index wins on ties.
template <typename T>
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
{
    uint32x4_t res_idx_mask{ 0 };
    uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);

    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        // Keep indices only in lanes equal to the global minimum.
        auto pmin    = calculate_min(vec_res_value);
        auto mask    = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
        res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
    }
    else
    {
        // Keep indices only in lanes equal to the global maximum.
        auto pmax    = calculate_max(vec_res_value);
        auto mask    = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
        res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
    }

    // Adding 0xFFFFFFFF maps losing lanes (0) to UINT32_MAX and winning lanes to
    // (index - 1), so a pairwise minimum selects the smallest winning index.
    res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
    auto pmin    = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
    pmin         = wrapper::vpmin(pmin, pmin);
    uint32_t res = wrapper::vgetlane(pmin, 0);

    // Undo the bias (subtracting 0xFFFFFFFF == adding 1 modulo 2^32).
    return (res - 0xFFFFFFFF);
}
202
// Quantized (8-bit lane) variant of calculate_vector_index: sixteen lanes and
// four index vectors are reduced to the single scalar index of the overall
// min/max. The lowest index wins on ties.
template <typename T>
uint32_t calculate_vector_index_quantized(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
{
    uint32x4x4_t res_idx_mask{ { 0 } };
    uint32x4_t   mask_ones = vdupq_n_u32(0xFFFFFFFF);
    uint8x16_t   mask_u8{ 0 };
    // 8-bit mask: 0xFF in lanes equal to the global min/max.
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        auto pmin = calculate_min(vec_res_value);
        mask_u8   = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
    }
    else
    {
        auto pmax = calculate_max(vec_res_value);
        mask_u8   = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
    }

    // Widen vectors: (x << lane_width) | x turns each 0xFF / 0x00 lane into an
    // all-ones / all-zeros 32-bit mask lane.
    auto wide_u16_1     = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
    auto wide_u16_2     = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
    auto wide_u32_1     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
    auto wide_u32_2     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
    auto wide_u32_3     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
    auto wide_u32_4     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
    // Keep only the indices of winning lanes, then bias by 0xFFFFFFFF so losing
    // lanes become UINT32_MAX and winning lanes become (index - 1).
    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
    res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
    res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
    res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
    res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);

    // Pairwise-minimum across all four vectors selects the smallest biased index.
    uint32_t res  = 0xFFFFFFFF;
    int      iter = 0;
    do
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
        pmin      = wrapper::vpmin(pmin, pmin);
        res       = std::min(wrapper::vgetlane(pmin, 0), res);
        iter++;
    }
    while(iter < 4);

    // Undo the bias (subtracting 0xFFFFFFFF == adding 1 modulo 2^32).
    return (res - 0xFFFFFFFF);
}
Luca Foschianiee939fb2020-01-28 10:38:07 +0000249
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000250#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// FP16 specialization of calculate_index: compares eight half-precision lanes
// at once, tracking candidate indices in two uint32x4_t vectors.
// a: candidate winners, b: previous winners, c: previously tracked indices.
template <>
uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4x2_t mask{ 0 };
    uint16x8_t   mask_u16{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask_u16 = wrapper::vcgt(b, a);
    }
    else
    {
        mask_u16 = wrapper::vclt(b, a);
    }
    // Widen each 0xFFFF/0x0000 16-bit mask lane to a full 32-bit all-ones/zeros
    // mask via (x << 16) | x, matching calculate_index_quantized. A plain vmovl
    // would produce 0x0000FFFF, making the bit-select below keep only the low
    // 16 bits of the candidate index (wrong for reduced dimensions >= 65536).
    mask.val[0] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 16), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
    mask.val[1] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 16), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
    // Candidate indices: consecutive along X, broadcast idx for any other axis.
    uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
                               { idx + 4, idx + 5, idx + 6, idx + 7 }
                             }
    };
    if(axis != 0)
    {
        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    // Bit-select: winning lanes take the candidate index, losing lanes keep c.
    uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
                         wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
                         0, 0
    };

    return res;
}
282
// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
// FP16 variant: eight lanes need three pairwise folds.
inline float16x4_t calculate_min(float16x8_t in)
{
    auto folded = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
    folded      = wrapper::vpmin(folded, folded);
    return wrapper::vpmin(folded, folded);
}
// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
// FP16 variant: eight lanes need three pairwise folds.
inline float16x4_t calculate_max(float16x8_t in)
{
    auto folded = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
    folded      = wrapper::vpmax(folded, folded);
    return wrapper::vpmax(folded, folded);
}
297
// FP16 specialization of calculate_vector_index: reduces eight half-precision
// winning lanes and their two tracked index vectors to the single scalar index
// of the overall min/max. The lowest index wins on ties.
template <>
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
{
    uint32x4x2_t res_idx_mask{ 0 };
    uint32x4_t   mask_ones = vdupq_n_u32(0xFFFFFFFF);
    uint16x8_t   mask_u16;
    // 16-bit mask: 0xFFFF in lanes equal to the global min/max.
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        auto pmin = calculate_min(vec_res_value);
        mask_u16  = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
    }
    else
    {
        auto pmax = calculate_max(vec_res_value);
        mask_u16  = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
    }

    // Widen vectors: (x << 16) | x turns each 0xFFFF/0x0000 lane into a full
    // 32-bit all-ones/zeros mask. The previous shift of 8 left the top byte
    // clear (0x00FFFFFF), truncating tracked indices to 24 bits; 16 matches the
    // widening pattern used by calculate_vector_index_quantized above.
    auto wide_u32_1     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 16), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
    auto wide_u32_2     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 16), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
    // Keep only the indices of winning lanes, then bias by 0xFFFFFFFF so losing
    // lanes become UINT32_MAX and winning lanes become (index - 1).
    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);

    // Pairwise-minimum across both vectors selects the smallest biased index.
    uint32_t res  = 0xFFFFFFFF;
    int      iter = 0;
    do
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
        pmin      = wrapper::vpmin(pmin, pmin);
        res       = std::min(wrapper::vgetlane(pmin, 0), res);
        iter++;
    }
    while(iter < 2);

    // Undo the bias (subtracting 0xFFFFFFFF == adding 1 modulo 2^32).
    return (res - 0xFFFFFFFF);
}
336#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
337
Georgios Pinitasd9769582017-08-03 10:19:40 +0100338template <class F>
339class Reducer
340{
341public:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000342 static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100343 {
344 // Set out window
345 Window out_window(window);
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100346 out_window.set(Window::DimX, Window::Dimension(0, 1, 1));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100347
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100348 f(window, out_window, input, output, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100349 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000350 static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100351 {
352 // Set in window
353 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000354 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100355
356 in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000357 out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100358
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100359 f(in_window, out_window, input, output, 1, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100360 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000361 static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100362 {
363 // Set in window
364 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000365 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100366
367 in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000368 out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100369
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100370 f(in_window, out_window, input, output, 2, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100371 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000372 static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100373 {
374 // Set in/out window
375 Window in_window(window);
376 Window out_window(window);
377
378 in_window.set(3, Window::Dimension(0, 1, 1));
379 out_window.set(3, Window::Dimension(0, 1, 1));
380
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100381 f(in_window, out_window, input, output, 3, op);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100382 }
383};
384
// Reduction along the X axis for non-quantized types (T = float/int lane type,
// S = number of lanes per 128-bit vector). Each window iteration reduces one
// full row to a single value (or index, for ARG_IDX_* ops).
template <typename T, int S>
struct RedOpX
{
    /** SIMD vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;

    inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
    {
        const TensorInfo in_info        = *(in->info());
        const int        window_step_x  = 16 / sizeof(T);
        const auto       window_start_x = static_cast<int>(in_window.x().start());
        const auto       window_end_x   = static_cast<int>(in_window.x().end());

        // Collapse X so each loop iteration sees one whole row; X is traversed
        // manually below.
        Window in_win_no_pad = in_window;
        in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input(in, in_win_no_pad);
        Iterator output(out, out_window);

        execute_window_loop(in_win_no_pad, [&](const Coordinates &)
        {
            const auto input_ptr = reinterpret_cast<const T *>(input.ptr());

            // Neutral element for the accumulator: first element for min/max ops,
            // 1 for PROD, 0 for the sum-based ops.
            auto init_res_value = static_cast<T>(0.f);
            switch(op)
            {
                case ReductionOperation::ARG_IDX_MAX:
                case ReductionOperation::ARG_IDX_MIN:
                case ReductionOperation::MIN:
                case ReductionOperation::MAX:
                {
                    init_res_value = static_cast<T>(*input_ptr);
                    break;
                }
                case ReductionOperation::PROD:
                {
                    init_res_value = static_cast<T>(1.f);
                    break;
                }
                default:
                    break;
            }
            auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
            uint32x4x4_t vec_res_idx{ { 0 } };

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto vec_elements = wrapper::vloadq(input_ptr + x);
                switch(op)
                {
                    case ReductionOperation::SUM_SQUARE:
                        vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                        break;
                    case ReductionOperation::MEAN_SUM:
                    case ReductionOperation::SUM:
                        vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                        break;
                    case ReductionOperation::PROD:
                        vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
                        break;
                    case ReductionOperation::ARG_IDX_MIN:
                    {
                        // Track both the winning values and their indices.
                        auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                        vec_res_idx   = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                        vec_res_value = temp_vec_res_value;
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MAX:
                    {
                        auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                        vec_res_idx   = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                        vec_res_value = temp_vec_res_value;
                        break;
                    }
                    case ReductionOperation::MIN:
                    {
                        vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                        break;
                    }
                    case ReductionOperation::MAX:
                    {
                        vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                        break;
                    }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            }

            // Horizontal reduction of the vector accumulator, then scalar handling
            // of the leftover tail elements.
            switch(op)
            {
                case ReductionOperation::SUM:
                case ReductionOperation::MEAN_SUM:
                case ReductionOperation::SUM_SQUARE:
                {
                    auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                    for(int i = 0; i < S / 4; ++i)
                    {
                        carry_res = wrapper::vpadd(carry_res, carry_res);
                    }
                    auto res = wrapper::vgetlane(carry_res, 0);

                    if(op == ReductionOperation::SUM_SQUARE)
                    {
                        // Compute left-over elements
                        for(; x < window_end_x; ++x)
                        {
                            res += (*(input_ptr + x)) * (*(input_ptr + x));
                        }
                    }
                    else
                    {
                        // Compute left-over elements
                        for(; x < window_end_x; ++x)
                        {
                            res += *(input_ptr + x);
                        }
                    }

                    if(op == ReductionOperation::MEAN_SUM)
                    {
                        // Divide by the reduced dimension length to get the mean.
                        res /= in_info.dimension(0);
                    }

                    *(reinterpret_cast<T *>(output.ptr())) = res;
                    break;
                }
                case ReductionOperation::PROD:
                {
                    auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                    T    res       = 1;
                    for(int i = 0; i < S / 2; ++i)
                    {
                        res *= wrapper::vgetlane(carry_res, i);
                    }

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        res *= *(input_ptr + x);
                    }

                    *(reinterpret_cast<T *>(output.ptr())) = res;
                    break;
                }
                case ReductionOperation::ARG_IDX_MIN:
                {
                    auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        if(*(input_ptr + x) < res)
                        {
                            idx = x;
                            res = *(input_ptr + x);
                        }
                    }
                    // ARG_IDX_* ops write a uint32 index, not a value of type T.
                    *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
                    break;
                }
                case ReductionOperation::ARG_IDX_MAX:
                {
                    auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        if(*(input_ptr + x) > res)
                        {
                            idx = x;
                            res = *(input_ptr + x);
                        }
                    }
                    *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
                    break;
                }
                case ReductionOperation::MIN:
                {
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
                    }
                    *(reinterpret_cast<T *>(output.ptr())) = res;
                    break;
                }
                case ReductionOperation::MAX:
                {
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
                    }
                    *(reinterpret_cast<T *>(output.ptr())) = res;
                    break;
                }
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        },
        input, output);
    }
};
597
Luca Foschianiee939fb2020-01-28 10:38:07 +0000598template <typename T>
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100599struct RedOpX_quantized
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100600{
Pablo Marquez Telloe81825b2021-03-23 15:47:47 +0000601 using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
602
603 using vtype = decltype(wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}));
604 using stype = decltype(vdupq_n_f32(static_cast<float>(1.f)));
605 using rtype = typename wrapper::traits::neon_vector<T, 16>::type;
606
    // Accumulates one full 16-lane vector of quantized input at *ptr into the
    // running accumulators, according to op:
    //  - SUM/MEAN_SUM: widen 8-bit lanes to 32-bit and add into vec_res_value1..4
    //  - PROD:         dequantize to float and multiply into vec_res_value1_f..4_f
    //  - ARG_IDX_*/MIN/MAX: update vec_res_value (and vec_res_idx for ARG_IDX_*)
    // x is the current element index along the reduced axis (used for index tracking).
    void vprocess(int x, const T *ptr, const ReductionOperation op, const UniformQuantizationInfo &iq_info,
                  vtype &vec_res_value1, vtype &vec_res_value2, vtype &vec_res_value3, vtype &vec_res_value4,

                  stype &vec_res_value1_f,
                  stype &vec_res_value2_f,
                  stype &vec_res_value3_f,
                  stype &vec_res_value4_f,
                  uint32x4x4_t &vec_res_idx,

                  rtype &vec_res_value)

    {
        const auto vec_elements = wrapper::vloadq(ptr);

        switch(op)
        {
            case ReductionOperation::SUM:
            case ReductionOperation::MEAN_SUM:
            {
                // Widen 8-bit -> 16-bit -> 32-bit so sums cannot overflow per step.
                const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
                const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));

                const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
                const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
                const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
                const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));

                vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
                vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
                vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
                vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
                break;
            }
            case ReductionOperation::PROD:
            {
                const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
                const auto scale32x4f_4  = vdupq_n_f32(iq_info.scale);

                // Widen 8-bit -> 32-bit, then convert to float for the product.
                const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
                const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));

                const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
                const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
                const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
                const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));

                auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
                auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
                auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
                auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);

                //de-quantize vec_elements
                temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
                temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
                temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
                temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);

                vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
                vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
                vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
                vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
                break;
            }
            case ReductionOperation::ARG_IDX_MIN:
            {
                // Track both winning values and their indices.
                auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                vec_res_idx   = calculate_index_quantized<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                vec_res_value = temp_vec_res_value;
                break;
            }
            case ReductionOperation::ARG_IDX_MAX:
            {
                auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                vec_res_idx   = calculate_index_quantized<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                vec_res_value = temp_vec_res_value;
                break;
            }
            case ReductionOperation::MIN:
            {
                vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                break;
            }
            case ReductionOperation::MAX:
            {
                vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                break;
            }
            default:
                ARM_COMPUTE_ERROR("Not supported");
        }
    }
    /** Finalise an x-axis reduction over quantized data and handle the scalar tail.
     *
     * Collapses the SIMD accumulators filled by the vectorised main loop into a single
     * scalar, folds in the remaining elements in [x, window_end_x) one at a time, and
     * writes the final reduced value (or arg-index) through @p output.
     *
     * @param x                First element index NOT processed by the vector loop.
     * @param window_end_x     End (exclusive) of the row being reduced.
     * @param input_ptr        Pointer to the beginning of the input row.
     * @param output           Output iterator positioned at the destination element.
     * @param op               Reduction operation being finalised.
     * @param in_info          Input tensor info; dimension(0) is the reduced extent.
     * @param iq_info          Uniform quantization info of the input tensor.
     * @param vec_res_value1   Widened integer accumulator 1 (SUM / MEAN_SUM path).
     * @param vec_res_value2   Widened integer accumulator 2 (SUM / MEAN_SUM path).
     * @param vec_res_value3   Widened integer accumulator 3 (SUM / MEAN_SUM path).
     * @param vec_res_value4   Widened integer accumulator 4 (SUM / MEAN_SUM path).
     * @param vec_res_value1_f Float accumulator 1 (PROD path, de-quantized domain).
     * @param vec_res_value2_f Float accumulator 2 (PROD path, de-quantized domain).
     * @param vec_res_value3_f Float accumulator 3 (PROD path, de-quantized domain).
     * @param vec_res_value4_f Float accumulator 4 (PROD path, de-quantized domain).
     * @param vec_res_idx      Per-lane candidate indices (ARG_IDX_MIN / ARG_IDX_MAX path).
     * @param vec_res_value    Vector of candidate values (MIN / MAX / ARG_IDX_* path).
     */
    void leftover_process(int x, int window_end_x, const T *input_ptr, Iterator output, const ReductionOperation op, const TensorInfo in_info,
                          const UniformQuantizationInfo &iq_info,
                          vtype &vec_res_value1, vtype &vec_res_value2, vtype &vec_res_value3, vtype &vec_res_value4,

                          stype &vec_res_value1_f,
                          stype &vec_res_value2_f,
                          stype &vec_res_value3_f,
                          stype &vec_res_value4_f,
                          uint32x4x4_t &vec_res_idx,

                          rtype &vec_res_value)

    {
        switch(op)
        {
            case ReductionOperation::ARG_IDX_MIN:
            {
                // Reduce the per-lane (index, value) candidates to a single scalar pair
                auto idx = calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
                auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));

                // Compute left-over elements
                for(; x < window_end_x; ++x)
                {
                    if(*(input_ptr + x) < res)
                    {
                        idx = x;
                        res = *(input_ptr + x);
                    }
                }
                // Output of ARG_IDX_* is the 32-bit index, not the value
                *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
                break;
            }
            case ReductionOperation::ARG_IDX_MAX:
            {
                // Reduce the per-lane (index, value) candidates to a single scalar pair
                auto idx = calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
                auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));

                // Compute left-over elements
                for(; x < window_end_x; ++x)
                {
                    if(*(input_ptr + x) > res)
                    {
                        idx = x;
                        res = *(input_ptr + x);
                    }
                }
                // Output of ARG_IDX_* is the 32-bit index, not the value
                *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
                break;
            }
            case ReductionOperation::MIN:
            {
                auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));

                // Compute left-over elements
                for(; x < window_end_x; ++x)
                {
                    res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
                }
                *(reinterpret_cast<T *>(output.ptr())) = res;
                break;
            }
            case ReductionOperation::MAX:
            {
                auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));

                // Compute left-over elements
                for(; x < window_end_x; ++x)
                {
                    res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
                }
                *(reinterpret_cast<T *>(output.ptr())) = res;
                break;
            }
            case ReductionOperation::PROD:
            {
                // Horizontal product: combine the four float accumulators, then the four lanes
                auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
                carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
                carry_res = wrapper::vmul(carry_res, vec_res_value4_f);

                float res = wrapper::vgetlane(carry_res, 0);
                res *= wrapper::vgetlane(carry_res, 1);
                res *= wrapper::vgetlane(carry_res, 2);
                res *= wrapper::vgetlane(carry_res, 3);

                // Compute left-over elements
                for(; x < window_end_x; ++x)
                {
                    //de-quantize input
                    if(std::is_same<T, uint8_t>::value)
                    {
                        res *= dequantize_qasymm8(*(input_ptr + x), iq_info);
                    }
                    else
                    {
                        res *= dequantize_qasymm8_signed(*(input_ptr + x), iq_info);
                    }
                }

                //re-quantize result
                if(std::is_same<T, uint8_t>::value)
                {
                    res = quantize_qasymm8(res, iq_info);
                }
                else
                {
                    res = quantize_qasymm8_signed(res, iq_info);
                }

                *reinterpret_cast<T *>(output.ptr()) = static_cast<T>(res);
                break;
            }
            case ReductionOperation::SUM:
            case ReductionOperation::MEAN_SUM:
            {
                // Horizontal add: combine the four widened accumulators, then pairwise-add lanes
                auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
                carry_res = wrapper::vadd(carry_res, vec_res_value3);
                carry_res = wrapper::vadd(carry_res, vec_res_value4);

                auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
                carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
                auto res = static_cast<int32_t>(wrapper::vgetlane(carry_paddition, 0));

                // Compute left-over elements
                for(; x < window_end_x; ++x)
                {
                    res += *(input_ptr + x);
                }

                if(op == ReductionOperation::MEAN_SUM)
                {
                    // Integer division by the reduced extent gives the (still quantized) mean
                    res /= static_cast<int32_t>(in_info.dimension(0));
                }
                else
                {
                    // Subtract accumulated offsets
                    res -= (in_info.dimension(0) - 1) * iq_info.offset;
                }
                *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(res);
                break;
            }
            default:
                ARM_COMPUTE_ERROR("Not supported");
        }
    }
842
    /** Run the x-axis reduction over quantized data for the given window.
     *
     * For each row, the vectorised loop (vprocess) consumes window_step_x elements
     * per iteration into SIMD accumulators; leftover_process() then finishes the
     * scalar tail and writes the single reduced result for the row.
     *
     * @param in_window  Execution window of the input; its x range spans the reduced axis.
     * @param out_window Execution window of the output.
     * @param in         Input tensor.
     * @param out        Output tensor.
     * @param op         Reduction operation to perform.
     */
    inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
    {
        // Promote twice: 8-bit element -> 16-bit -> 32-bit accumulator lane type
        using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;

        const TensorInfo in_info = *(in->info());
        const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();

        const int window_step_x = 16 / sizeof(T);
        const auto window_start_x = static_cast<int>(in_window.x().start());
        const auto window_end_x = static_cast<int>(in_window.x().end());

        // Collapse the x dimension: the whole row is reduced inside one loop-body invocation
        Window in_win_no_pad = in_window;
        in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input(in, in_win_no_pad);
        Iterator output(out, out_window);

        execute_window_loop(in_win_no_pad, [&](const Coordinates &)
        {
            const auto input_ptr = reinterpret_cast<T *>(input.ptr());

            // Integer accumulators (SUM / MEAN_SUM), zero-initialised
            auto vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
            auto vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
            auto vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
            auto vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});

            // Float accumulators (PROD): neutral element 1
            auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
            auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
            auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
            auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));

            typename wrapper::traits::neon_vector<T, 16>::type vec_res_value = { 0 };

            // Min/max style ops must be seeded with a real element, not zero
            if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX)
            {
                vec_res_value = wrapper::vdup_n(*input_ptr, wrapper::traits::vector_128_tag{});
            }

            uint32x4x4_t vec_res_idx{ { 0 } };
            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                vprocess(x, input_ptr + x, op, iq_info,
                         vec_res_value1, vec_res_value2, vec_res_value3, vec_res_value4,
                         vec_res_value1_f,
                         vec_res_value2_f,
                         vec_res_value3_f,
                         vec_res_value4_f,
                         vec_res_idx, vec_res_value);
            }
            // Finalise accumulators and fold in the scalar tail
            leftover_process(x, window_end_x, input_ptr, output, op, in_info, iq_info,
                             vec_res_value1, vec_res_value2, vec_res_value3, vec_res_value4,
                             vec_res_value1_f,
                             vec_res_value2_f,
                             vec_res_value3_f,
                             vec_res_value4_f,
                             vec_res_idx, vec_res_value);

        },
        input, output);
    }
905};
906
/** Reduction along a non-x axis (y/z/w) for non-quantized element types.
 *
 * Each x position is reduced independently across in_info.dimension(axis); the
 * vectorised loop handles window_step_x x-positions at once, the scalar loop
 * finishes the tail.
 *
 * @tparam T Element type.
 * @tparam S Number of elements per SIMD vector.
 */
template <typename T, int S>
struct RedOpYZW
{
    /** SIMD vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
    using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;

    /** Reduce @p in along @p axis into @p out using operation @p op. */
    inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int axis, const ReductionOperation op)
    {
        const TensorInfo in_info = *(in->info());
        const int window_step_x = 16 / sizeof(T);
        const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
        const auto window_end_x_tmp = static_cast<int>(in_window.x().end());
        // As the work is split over the x-axis, set the correct split window start and end.
        const auto window_start_x = static_cast<int>(0);
        const auto window_end_x = static_cast<int>(in_window.shape().x());

        Window in_win_no_pad = in_window;
        in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
        Window out_win_no_pad = out_window;
        out_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));

        Iterator input(in, in_win_no_pad);
        Iterator output(out, out_win_no_pad);

        execute_window_loop(in_win_no_pad, [&](const Coordinates &)
        {
            const auto input_ptr = reinterpret_cast<T *>(input.ptr());

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                neon_vector vec_res_value = { 0 };
                // Initialise the accumulator with the op's neutral/seed value
                switch(op)
                {
                    case ReductionOperation::ARG_IDX_MAX:
                    case ReductionOperation::ARG_IDX_MIN:
                    case ReductionOperation::MIN:
                    case ReductionOperation::MAX:
                    {
                        // Seed min/max searches with the first slice's elements
                        vec_res_value = wrapper::vloadq(input_ptr + x);
                        break;
                    }
                    case ReductionOperation::PROD:
                    {
                        vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
                        break;
                    }
                    default:
                    {
                        vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
                        break;
                    }
                }
                uint32x4x4_t vec_res_idx{ { 0 } };

                // Walk along the reduced axis, accumulating one vector of x-positions
                for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                {
                    const T *in_ptr = reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);
                    const auto vec_elements = wrapper::vloadq(in_ptr);
                    switch(op)
                    {
                        case ReductionOperation::SUM:
                        case ReductionOperation::MEAN_SUM:
                            vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                            break;
                        case ReductionOperation::SUM_SQUARE:
                            vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                            break;
                        case ReductionOperation::PROD:
                            vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
                            break;
                        case ReductionOperation::ARG_IDX_MIN:
                        {
                            auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                            // Record the dim index for lanes where a new minimum appeared
                            vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                            vec_res_value = temp_vec_res_value;
                            break;
                        }
                        case ReductionOperation::ARG_IDX_MAX:
                        {
                            auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                            // Record the dim index for lanes where a new maximum appeared
                            vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                            vec_res_value = temp_vec_res_value;
                            break;
                        }
                        case ReductionOperation::MIN:
                        {
                            vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                            break;
                        }
                        case ReductionOperation::MAX:
                        {
                            vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                            break;
                        }
                        default:
                            ARM_COMPUTE_ERROR("Not supported");
                    }
                }

                if(op == ReductionOperation::MEAN_SUM)
                {
                    // Multiply by the reciprocal of the reduced extent to get the mean
                    auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
                    vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
                }

                if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
                {
                    // ARG_IDX_* writes 32-bit indices
                    wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    // For fp16 (8 lanes) a second block of 4 indices is needed
                    if(std::is_same<T, float16_t>::value)
                    {
                        wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, vec_res_idx.val[1]);
                    }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                }
                else
                {
                    wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value);
                }
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                // Scalar accumulator; seeded per-op as in the vector path
                auto res_value = 0.f;
                switch(op)
                {
                    case ReductionOperation::ARG_IDX_MAX:
                    case ReductionOperation::ARG_IDX_MIN:
                    case ReductionOperation::MIN:
                    case ReductionOperation::MAX:
                    {
                        res_value = *(input_ptr + x);
                        break;
                    }
                    case ReductionOperation::PROD:
                    {
                        res_value = static_cast<T>(1.f);
                        break;
                    }
                    default:
                    {
                        res_value = static_cast<T>(0.f);
                        break;
                    }
                }

                uint32_t res_idx = 0;
                for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                {
                    const T *in_ptr = reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);

                    switch(op)
                    {
                        case ReductionOperation::SUM:
                        case ReductionOperation::MEAN_SUM:
                            res_value += *in_ptr;
                            break;
                        case ReductionOperation::SUM_SQUARE:
                            res_value += *in_ptr * *in_ptr;
                            break;
                        case ReductionOperation::PROD:
                            res_value *= *in_ptr;
                            break;
                        case ReductionOperation::ARG_IDX_MIN:
                        {
                            if(*in_ptr < res_value)
                            {
                                res_value = *in_ptr;
                                res_idx = dim;
                            }
                            break;
                        }
                        case ReductionOperation::ARG_IDX_MAX:
                        {
                            if(*in_ptr > res_value)
                            {
                                res_value = *in_ptr;
                                res_idx = dim;
                            }
                            break;
                        }
                        case ReductionOperation::MIN:
                        {
                            res_value = *in_ptr < res_value ? *in_ptr : res_value;
                            break;
                        }
                        case ReductionOperation::MAX:
                        {
                            res_value = *in_ptr > res_value ? *in_ptr : res_value;
                            break;
                        }
                        default:
                            ARM_COMPUTE_ERROR("Not supported");
                    }
                }

                if(op == ReductionOperation::MEAN_SUM)
                {
                    res_value /= in_info.dimension(axis);
                }

                if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
                {
                    *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx;
                }
                else
                {
                    *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value;
                }
            }
        },
        input, output);
    }
};
1125
/** SUM reduction along the z axis for complex (interleaved real/imaginary pair) data.
 *
 * Elements are stored as pairs, so every x position occupies 2 * sizeof(T) bytes
 * (see the `2 * x * sizeof(T)` addressing below). Only axis == 2 with
 * ReductionOperation::SUM is supported; both are enforced at runtime.
 *
 * @tparam T    Element component type.
 * @tparam S    Number of components per SIMD vector.
 * @tparam axis Reduction axis (must be 2).
 * @tparam op   Reduction operation (must be SUM).
 */
template <typename T, int S, int axis, ReductionOperation op>
struct RedOpYZW_complex
{
    /** SIMD vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
    using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;

    /** Sum @p in along the z axis into @p out (axis/op arguments are ignored; the template parameters rule). */
    inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation)
    {
        ARM_COMPUTE_ERROR_ON(axis != 2);
        ARM_COMPUTE_ERROR_ON(op != ReductionOperation::SUM);

        const TensorInfo in_info = *(in->info());
        const size_t stride_z = in_info.strides_in_bytes()[axis];
        const int window_step_x = 16 / sizeof(T);
        const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
        const auto window_end_x_tmp = static_cast<int>(in_window.x().end());
        // As the work is split over the x-axis, set the correct split window start and end.
        const auto window_start_x = static_cast<int>(0);
        const auto window_end_x = static_cast<int>(in_window.shape().x());

        Window in_win_no_pad = in_window;
        in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
        Window out_win_no_pad = out_window;
        out_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));

        Iterator input(in, in_win_no_pad);
        Iterator output(out, out_win_no_pad);

        execute_window_loop(in_win_no_pad, [&](const Coordinates &)
        {
            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                neon_vector vec_res_value_0 = { 0 };
                neon_vector vec_res_value_1 = { 0 };

                vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
                vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});

                T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
                for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                {
                    // Two adjacent 16-byte loads cover the interleaved pair components
                    T *in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
                    T *in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim);

                    const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
                    const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);

                    vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
                    vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
                }

                wrapper::vstore(out_ptr, vec_res_value_0);
                wrapper::vstore(out_ptr + 4, vec_res_value_1);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                // Accumulate the two pair components of this x position separately
                auto res_value_0 = 0.f;
                auto res_value_1 = 0.f;

                T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
                for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                {
                    T *in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
                    res_value_0 += *in_ptr;
                    res_value_1 += *(in_ptr + 1);
                }
                *out_ptr = res_value_0;
                *(out_ptr + 1) = res_value_1;
            }
        },
        input, output);
    }
};
1204
1205template <typename T>
1206struct RedOpYZW_quantized
1207{
1208 inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int axis, const ReductionOperation op)
1209 {
Georgios Pinitas412b7892020-11-11 21:05:24 +00001210 const TensorInfo in_info = *(in->info());
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001211 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
Georgios Pinitas412b7892020-11-11 21:05:24 +00001212 using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001213
Georgios Pinitas412b7892020-11-11 21:05:24 +00001214 const int window_step_x = 16 / sizeof(T);
1215 const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
1216 const auto window_end_x_tmp = static_cast<int>(in_window.x().end());
1217 // As it split over x-axis, need to set the correct spiltted window start and end.
1218 const auto window_start_x = static_cast<int>(0);
1219 const auto window_end_x = static_cast<int>(in_window.shape().x());
1220
1221 Window in_win_no_pad = in_window;
1222 in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
1223 Window out_win_no_pad = out_window;
1224 out_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));
1225
1226 Iterator input(in, in_win_no_pad);
1227 Iterator output(out, out_win_no_pad);
1228
1229 execute_window_loop(in_win_no_pad, [&](const Coordinates &)
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001230 {
1231 const auto input_ptr = reinterpret_cast<T *>(input.ptr());
1232
1233 // Compute window_step_x elements per iteration
1234 int x = window_start_x;
1235 for(; x <= (window_end_x - window_step_x); x += window_step_x)
1236 {
1237 uint32x4x4_t vec_res_idx{ { 0 } };
1238 auto vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1239 auto vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1240 auto vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1241 auto vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1242
1243 auto vec_res_value1_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1244 auto vec_res_value2_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1245 auto vec_res_value3_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1246 auto vec_res_value4_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1247
1248 auto vec_res_value = wrapper::vloadq(input_ptr + x);
1249
1250 for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
1251 {
1252 const T *in_ptr = input_ptr + x + in_info.strides_in_bytes()[axis] * index_dim;
1253 const auto vec_elements = wrapper::vloadq(in_ptr);
1254 switch(op)
1255 {
1256 case ReductionOperation::SUM:
1257 case ReductionOperation::MEAN_SUM:
1258 {
1259 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
1260 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
1261
1262 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
1263 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
1264 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
1265 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
1266
1267 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
1268 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
1269 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
1270 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
1271 break;
1272 }
1273 case ReductionOperation::PROD:
1274 {
1275 const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
1276 const auto scale32x4f_4 = wrapper::vdup_n(iq_info.scale, wrapper::traits::vector_128_tag{});
1277
1278 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
1279 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
1280
1281 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
1282 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
1283 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
1284 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
1285
1286 auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
1287 auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
1288 auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
1289 auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
1290
1291 //de-quantize vec_elements
1292 temp32x4f_1 = wrapper::vmul(wrapper::vsub(temp32x4f_1, offset32x4f_4), scale32x4f_4);
1293 temp32x4f_2 = wrapper::vmul(wrapper::vsub(temp32x4f_2, offset32x4f_4), scale32x4f_4);
1294 temp32x4f_3 = wrapper::vmul(wrapper::vsub(temp32x4f_3, offset32x4f_4), scale32x4f_4);
1295 temp32x4f_4 = wrapper::vmul(wrapper::vsub(temp32x4f_4, offset32x4f_4), scale32x4f_4);
1296
1297 vec_res_value1_f = wrapper::vmul(temp32x4f_1, vec_res_value1_f);
1298 vec_res_value2_f = wrapper::vmul(temp32x4f_2, vec_res_value2_f);
1299 vec_res_value3_f = wrapper::vmul(temp32x4f_3, vec_res_value3_f);
1300 vec_res_value4_f = wrapper::vmul(temp32x4f_4, vec_res_value4_f);
1301 break;
1302 }
1303 case ReductionOperation::ARG_IDX_MIN:
1304 {
1305 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1306 vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
1307 vec_res_value = temp_vec_res_value;
1308 break;
1309 }
1310 case ReductionOperation::ARG_IDX_MAX:
1311 {
1312 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1313 vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
1314 vec_res_value = temp_vec_res_value;
1315 break;
1316 }
1317 case ReductionOperation::MIN:
1318 {
1319 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1320 break;
1321 }
1322 case ReductionOperation::MAX:
1323 {
1324 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1325 break;
1326 }
1327 default:
1328 ARM_COMPUTE_ERROR("Not supported");
1329 }
1330 }
1331
1332 switch(op)
1333 {
1334 case ReductionOperation::ARG_IDX_MIN:
1335 case ReductionOperation::ARG_IDX_MAX:
1336 {
1337 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x), vec_res_idx.val[0]);
1338 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 4, vec_res_idx.val[1]);
1339 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 8, vec_res_idx.val[2]);
1340 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 12, vec_res_idx.val[3]);
1341 break;
1342 }
1343 case ReductionOperation::MIN:
1344 case ReductionOperation::MAX:
1345 {
1346 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), vec_res_value);
1347 break;
1348 }
1349 case ReductionOperation::SUM:
1350 {
1351 // Subtract offsets
1352 auto offsets = vdupq_n_s32((in_info.dimension(axis) - 1) * iq_info.offset);
1353
1354 auto vec_res_s_value1 = wrapper::vreinterpret(vec_res_value1);
1355 auto vec_res_s_value2 = wrapper::vreinterpret(vec_res_value2);
1356 auto vec_res_s_value3 = wrapper::vreinterpret(vec_res_value3);
1357 auto vec_res_s_value4 = wrapper::vreinterpret(vec_res_value4);
1358
1359 vec_res_s_value1 = wrapper::vsub(vec_res_s_value1, offsets);
1360 vec_res_s_value2 = wrapper::vsub(vec_res_s_value2, offsets);
1361 vec_res_s_value3 = wrapper::vsub(vec_res_s_value3, offsets);
1362 vec_res_s_value4 = wrapper::vsub(vec_res_s_value4, offsets);
1363
1364 const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_s_value1), wrapper::vqmovn(vec_res_s_value2));
1365 const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_s_value3), wrapper::vqmovn(vec_res_s_value4));
1366
1367 combine_and_store<T>(temp16x8t_1, temp16x8t_2, output, x);
1368 break;
1369 }
1370 case ReductionOperation::MEAN_SUM:
1371 {
1372 const auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<float>(in_info.dimension(axis)), wrapper::traits::vector_128_tag{}));
1373 vec_res_value1_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value1), vec_width_inv);
1374 vec_res_value2_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value2), vec_width_inv);
1375 vec_res_value3_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value3), vec_width_inv);
1376 vec_res_value4_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value4), vec_width_inv);
1377
1378 vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f);
1379 vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f);
1380 vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f);
1381 vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f);
1382
1383 const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1384 const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1385 auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
1386
1387 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res);
1388 break;
1389 }
1390 case ReductionOperation::PROD:
1391 {
1392 const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
1393 const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale));
1394
1395 //re-quantize
1396 vec_res_value1_f = wrapper::vadd(wrapper::vmul(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
1397 vec_res_value2_f = wrapper::vadd(wrapper::vmul(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
1398 vec_res_value3_f = wrapper::vadd(wrapper::vmul(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
1399 vec_res_value4_f = wrapper::vadd(wrapper::vmul(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
1400
1401 vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f);
1402 vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f);
1403 vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f);
1404 vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f);
1405
1406 const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1407 const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1408 auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
1409
1410 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res);
1411 break;
1412 }
1413 default:
1414 ARM_COMPUTE_ERROR("Not supported");
1415 }
1416 }
1417
1418 // Compute left-over elements
1419 for(; x < window_end_x; ++x)
1420 {
Michalis Spyrou272e4252020-10-06 17:44:40 +01001421 float res_value = 0.f;
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001422 switch(op)
1423 {
1424 case ReductionOperation::ARG_IDX_MAX:
1425 case ReductionOperation::ARG_IDX_MIN:
1426 case ReductionOperation::MIN:
1427 case ReductionOperation::MAX:
1428 {
1429 res_value = *(input_ptr + x);
1430 break;
1431 }
1432 case ReductionOperation::PROD:
1433 {
1434 res_value = static_cast<T>(1.0f);
1435 break;
1436 }
1437 default:
1438 {
1439 res_value = static_cast<T>(0.0f);
1440 break;
1441 }
1442 }
1443 uint32_t res_idx = 0;
1444
1445 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1446 {
1447 const T *in_ptr = reinterpret_cast<T *>(input.ptr() + x + in_info.strides_in_bytes()[axis] * dim);
1448 switch(op)
1449 {
1450 case ReductionOperation::SUM:
1451 case ReductionOperation::MEAN_SUM:
1452 {
1453 res_value += *in_ptr;
1454 break;
1455 }
1456 case ReductionOperation::SUM_SQUARE:
1457 {
1458 res_value += *in_ptr * *in_ptr;
1459 break;
1460 }
1461 case ReductionOperation::PROD:
1462 {
1463 //de-quantize input
1464 if(std::is_same<T, uint8_t>::value)
1465 {
Michalis Spyrou272e4252020-10-06 17:44:40 +01001466 res_value *= dequantize_qasymm8(*in_ptr, iq_info);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001467 }
1468 else
1469 {
Michalis Spyrou272e4252020-10-06 17:44:40 +01001470 res_value *= dequantize_qasymm8_signed(*in_ptr, iq_info);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001471 }
1472 break;
1473 }
1474 case ReductionOperation::ARG_IDX_MIN:
1475 {
1476 if(*in_ptr < res_value)
1477 {
1478 res_value = *in_ptr;
1479 res_idx = dim;
1480 }
1481 break;
1482 }
1483 case ReductionOperation::ARG_IDX_MAX:
1484 {
1485 if(*in_ptr > res_value)
1486 {
1487 res_value = *in_ptr;
1488 res_idx = dim;
1489 }
1490 break;
1491 }
1492 case ReductionOperation::MIN:
1493 {
1494 res_value = *in_ptr < res_value ? *in_ptr : res_value;
1495 break;
1496 }
1497 case ReductionOperation::MAX:
1498 {
1499 res_value = *in_ptr > res_value ? *in_ptr : res_value;
1500 break;
1501 }
1502 default:
1503 ARM_COMPUTE_ERROR("Not supported");
1504 }
1505 }
1506
1507 switch(op)
1508 {
1509 case ReductionOperation::MEAN_SUM:
1510 {
Michalis Spyrou272e4252020-10-06 17:44:40 +01001511 int32_t res = static_cast<int32_t>(res_value);
Sang-Hoon Parkcbede282020-10-12 21:44:23 +01001512 res /= static_cast<int32_t>(in_info.dimension(axis));
Michalis Spyrou272e4252020-10-06 17:44:40 +01001513 *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001514 break;
1515 }
1516 case ReductionOperation::SUM:
1517 {
1518 // Subtract accumulated offsets
1519 res_value -= (in_info.dimension(axis) - 1) * iq_info.offset;
1520 *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res_value);
1521 break;
1522 }
1523 case ReductionOperation::PROD:
1524 {
1525 //re-quantize result
Michalis Spyrou272e4252020-10-06 17:44:40 +01001526 T res = 0;
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001527 if(std::is_same<T, uint8_t>::value)
1528 {
Michalis Spyrou272e4252020-10-06 17:44:40 +01001529 res = quantize_qasymm8(res_value, iq_info);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001530 }
1531 else
1532 {
Michalis Spyrou272e4252020-10-06 17:44:40 +01001533 res = quantize_qasymm8_signed(res_value, iq_info);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001534 }
Michalis Spyrou272e4252020-10-06 17:44:40 +01001535 *(reinterpret_cast<T *>(output.ptr() + x)) = res;
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001536 break;
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001537 }
1538 case ReductionOperation::ARG_IDX_MIN:
1539 case ReductionOperation::ARG_IDX_MAX:
1540 {
1541 *(reinterpret_cast<uint32_t *>(output.ptr() + x * 4)) = res_idx;
1542 break;
1543 }
1544 default:
1545 *(reinterpret_cast<T *>(output.ptr() + x)) = res_value;
1546 }
1547 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001548 },
1549 input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001550 }
1551};
1552
/** Dispatch a reduction to the kernel specialization matching the reduction axis and the input data type.
 *
 * @param[in]  window Execution window to iterate over.
 * @param[in]  input  Source tensor. 1-channel input: QASYMM8/QASYMM8_SIGNED/F16(*)/F32/S32.
 *                    2-channel input is treated as complex and only supports F32 (see below).
 *                    (*) F16 paths are compiled only when __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is defined.
 * @param[out] output Destination tensor.
 * @param[in]  axis   Axis to reduce along: 0 (X), 1 (Y), 2 (Z) or 3 (W).
 * @param[in]  op     Reduction operation to perform.
 */
void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
{
    // A 2-channel tensor carries complex values (handled by RedOpYZW_complex below).
    const bool is_complex = (input->info()->num_channels() == 2);

    if(is_complex)
    {
        // Complex input: only SUM over the Z axis on F32 data is implemented;
        // every other combination is rejected at runtime.
        switch(axis)
        {
            case 2:
                switch(input->info()->data_type())
                {
                    case DataType::F32:
                        switch(op)
                        {
                            case ReductionOperation::SUM:
                                return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op);
                            default:
                                ARM_COMPUTE_ERROR("Not supported");
                        }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            default:
                ARM_COMPUTE_ERROR("Not supported");
        }
    }

    // Real (1-channel) input: pick the functor family by axis, then the
    // element type / vector width by data type. Quantized types use the
    // dedicated *_quantized functors which de/re-quantize around the math.
    switch(axis)
    {
        // Reduce along the X (innermost) axis.
        case 0:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpX_quantized<uint8_t>>::reduceX(window, input, output, RedOpX_quantized<uint8_t>(), op);
                case DataType::QASYMM8_SIGNED:
                    return Reducer<RedOpX_quantized<int8_t>>::reduceX(window, input, output, RedOpX_quantized<int8_t>(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
                case DataType::S32:
                    return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        // Reduce along the Y axis.
        case 1:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_quantized<uint8_t>>::reduceY(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
                case DataType::QASYMM8_SIGNED:
                    return Reducer<RedOpYZW_quantized<int8_t>>::reduceY(window, input, output, RedOpYZW_quantized<int8_t>(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
                case DataType::S32:
                    return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        // Reduce along the Z axis.
        case 2:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_quantized<uint8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
                case DataType::QASYMM8_SIGNED:
                    return Reducer<RedOpYZW_quantized<int8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<int8_t>(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
                case DataType::S32:
                    return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        // Reduce along the W (outermost supported) axis.
        case 3:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_quantized<uint8_t>>::reduceW(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
                case DataType::QASYMM8_SIGNED:
                    return Reducer<RedOpYZW_quantized<int8_t>>::reduceW(window, input, output, RedOpYZW_quantized<int8_t>(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
                case DataType::S32:
                    return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        default:
            ARM_COMPUTE_ERROR("Unsupported reduction axis");
    }
}
John Richardson73d4aef2018-05-08 14:34:33 +01001658
/** Validate input/output tensor infos and parameters for the reduction kernel.
 *
 * @param[in] input  Source tensor info.
 * @param[in] output Destination tensor info. May be empty (total_size() == 0), in which
 *                   case only the input/parameter checks are performed.
 * @param[in] axis   Axis along which to reduce (must be <= 3).
 * @param[in] op     Reduction operation to perform.
 *
 * @return A Status holding an error on failure, empty on success.
 */
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
    ARM_COMPUTE_UNUSED(op);

    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);

    if(input->num_channels() == 1)
    {
        // Real input: all supported data types are allowed.
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
    }
    else
    {
        // 2-channel (complex) input: only F32 SUM over axis 2 is supported
        // (mirrors the dispatch in reduce_op).
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
        ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM);
        ARM_COMPUTE_RETURN_ERROR_ON(axis != 2);
    }

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");

    // Only check the output if it has already been initialized.
    if(output->total_size() != 0)
    {
        bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
        if(!is_arg_min_max)
        {
            // Value reductions keep the input's type/quantization/channel count.
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
            ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels());
        }
        else
        {
            // Index reductions produce integer indices.
            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32, DataType::S32);
        }

        // The output shape must equal the input shape with the reduced axis collapsed.
        const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
        const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
    }

    return Status{};
}
Georgios Pinitasd9769582017-08-03 10:19:40 +01001701} // namespace
1702
/** Default constructor.
 *
 * Tensors start unbound and the axis zeroed; configure() must be called
 * before run(). The initial _op value (SUM_SQUARE) is a placeholder that
 * configure() overwrites.
 */
NEReductionOperationKernel::NEReductionOperationKernel()
    : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE)
{
}
1707
Georgios Pinitasd9769582017-08-03 10:19:40 +01001708void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
1709{
1710 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001711
John Richardson73d4aef2018-05-08 14:34:33 +01001712 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001713
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001714 _input = input;
1715 _output = output;
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001716 _op = op;
1717 _reduction_axis = axis;
Georgios Pinitasd9769582017-08-03 10:19:40 +01001718
1719 // Configure kernel window
Georgios Pinitas412b7892020-11-11 21:05:24 +00001720 Window win = calculate_max_window(*input->info(), Steps());
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001721 INEKernel::configure(win);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001722
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001723 // Calculate output shape and set if empty
1724 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->info()->tensor_shape(), axis);
1725 // Output auto initialization if not yet initialized
1726 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
1727 DataType output_data_type = is_arg_min_max ? DataType::S32 : input->info()->data_type();
1728 auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
John Richardson73d4aef2018-05-08 14:34:33 +01001729}
1730
1731Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1732{
1733 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
John Richardson73d4aef2018-05-08 14:34:33 +01001734
1735 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +01001736}
1737
/** Execute the reduction over the given window.
 *
 * @param[in] window Region on which to execute the kernel; must lie within
 *                   the window set at configure() time.
 * @param[in] info   Info about the executing thread (unused: the thread id is
 *                   not needed, the window alone defines the work).
 */
void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    // Fail fast if configure() was never called or the window is out of bounds.
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    // Dispatch to the axis/data-type-specialized implementation.
    reduce_op(window, _input, _output, _reduction_axis, _op);
}
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001746} // namespace arm_compute