blob: 19955af4939b5f66be6199cd1323e97f196b06bb [file] [log] [blame]
Georgios Pinitasd9769582017-08-03 10:19:40 +01001/*
Mohammed Suhail Munshi470cc5d2023-02-09 11:52:06 +00002 * Copyright (c) 2017-2023 Arm Limited.
Georgios Pinitasd9769582017-08-03 10:19:40 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrouebcebf12020-10-21 00:04:14 +010024#include "src/core/NEON/kernels/NEReductionOperationKernel.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010025
26#include "arm_compute/core/Coordinates.h"
27#include "arm_compute/core/Helpers.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010028#include "arm_compute/core/ITensor.h"
John Richardson73d4aef2018-05-08 14:34:33 +010029#include "arm_compute/core/TensorInfo.h"
Luca Foschianiee939fb2020-01-28 10:38:07 +000030#include "arm_compute/core/Utils.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010031#include "arm_compute/core/Validate.h"
Michalis Spyrouaea14c62019-01-03 11:10:25 +000032#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010033#include "src/core/CPP/Validate.h"
Michalis Spyrouebcebf12020-10-21 00:04:14 +010034#include "src/core/NEON/INEKernel.h"
Georgios Pinitasddb93bb2020-10-02 16:38:59 +010035#include "src/core/NEON/NEMath.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010036#include "src/core/helpers/AutoConfiguration.h"
37#include "src/core/helpers/WindowHelpers.h"
38#include "support/SaturateCast.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010039
Georgios Pinitasddb93bb2020-10-02 16:38:59 +010040#include "src/core/NEON/wrapper/wrapper.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010041#include <arm_neon.h>
42
Michalis Spyroubcf8a962018-10-12 10:51:31 +010043namespace arm_compute
44{
Georgios Pinitasd9769582017-08-03 10:19:40 +010045namespace
46{
Luca Foschianiee939fb2020-01-28 10:38:07 +000047// Helper function that calls vqmovun/vqmvn, vcombine and vstore, allows templating of RedOpYZW_quantized
48template <typename T>
Sheri Zhang4d91dc62020-09-23 11:22:50 +010049void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0)
Luca Foschianiee939fb2020-01-28 10:38:07 +000050{
51 if(std::is_same<T, uint8_t>::value)
52 {
53 auto res = wrapper::vcombine(wrapper::vqmovun(t1), wrapper::vqmovun(t2));
Sheri Zhang4d91dc62020-09-23 11:22:50 +010054 wrapper::vstore(output.ptr() + offset, res);
Luca Foschianiee939fb2020-01-28 10:38:07 +000055 }
56 else
57 {
58 auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2));
Sheri Zhang4d91dc62020-09-23 11:22:50 +010059 wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), res);
Luca Foschianiee939fb2020-01-28 10:38:07 +000060 }
61}
62
Michalis Spyroub9626ab2019-05-13 17:41:01 +010063template <typename T>
64uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000065{
66 uint32x4_t mask{ 0 };
67 if(op == ReductionOperation::ARG_IDX_MIN)
68 {
69 mask = wrapper::vcgt(b, a);
70 }
71 else
72 {
73 mask = wrapper::vclt(b, a);
74 }
75
76 uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
77 if(axis != 0)
78 {
79 vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
80 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +000081 uint32x4x4_t res = { { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000082
83 return res;
84}
85
Luca Foschianiee939fb2020-01-28 10:38:07 +000086template <typename T>
87uint32x4x4_t calculate_index_quantized(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000088{
Georgios Pinitasd57891a2019-02-19 18:10:03 +000089 uint32x4x4_t mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000090 uint8x16_t mask_u8{ 0 };
91 if(op == ReductionOperation::ARG_IDX_MIN)
92 {
93 mask_u8 = wrapper::vcgt(b, a);
94 }
95 else
96 {
97 mask_u8 = wrapper::vclt(b, a);
98 }
Michalis Spyrou254a48a2019-01-14 17:27:39 +000099 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
100 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
101 mask.val[0] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
102 mask.val[1] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
103 mask.val[2] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
104 mask.val[3] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
105
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000106 uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
107 { idx + 4, idx + 5, idx + 6, idx + 7 },
108 { idx + 8, idx + 9, idx + 10, idx + 11 },
109 { idx + 12, idx + 13, idx + 14, idx + 15 }
110 }
111 };
112 if(axis != 0)
113 {
114 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
115 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
116 vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
117 vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
118 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000119 uint32x4x4_t res =
120 {
121 {
122 vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
123 vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
124 vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
125 vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
126 }
127 };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000128
129 return res;
130}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100131
132// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000133template <typename T>
134inline typename std::enable_if < std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
135 typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type >::type
136 calculate_min(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100137{
138 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
139 return wrapper::vpmin(pmin, pmin);
140}
141
Luca Foschianiee939fb2020-01-28 10:38:07 +0000142// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
143template <typename T>
144inline typename std::enable_if < std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
145 typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type >::type
146 calculate_min(T in)
147{
148 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
149 pmin = wrapper::vpmin(pmin, pmin);
150 pmin = wrapper::vpmin(pmin, pmin);
151 return wrapper::vpmin(pmin, pmin);
152}
153
Usama Arifa4a08ad2019-05-20 12:38:33 +0100154// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000155template <typename T>
156inline typename std::enable_if < std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
157 typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type >::type
158 calculate_max(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100159{
160 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
161 return wrapper::vpmax(pmax, pmax);
162}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100163
164// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000165template <typename T>
166inline typename std::enable_if < std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
167 typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type >::type
168 calculate_max(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100169{
170 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
Luca Foschianiee939fb2020-01-28 10:38:07 +0000171 pmax = wrapper::vpmax(pmax, pmax);
172 pmax = wrapper::vpmax(pmax, pmax);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100173 return wrapper::vpmax(pmax, pmax);
174}
175
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100176template <typename T>
177uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000178{
179 uint32x4_t res_idx_mask{ 0 };
180 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
181
182 if(op == ReductionOperation::ARG_IDX_MIN)
183 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100184 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000185 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
186 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
187 }
188 else
189 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100190 auto pmax = calculate_max(vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100191 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000192 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
193 }
194
195 res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
196 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
197 pmin = wrapper::vpmin(pmin, pmin);
198 uint32_t res = wrapper::vgetlane(pmin, 0);
199
200 return (res - 0xFFFFFFFF);
201}
202
Luca Foschianiee939fb2020-01-28 10:38:07 +0000203template <typename T>
204uint32_t calculate_vector_index_quantized(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000205{
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000206 uint32x4x4_t res_idx_mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000207 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
208 uint8x16_t mask_u8{ 0 };
209 if(op == ReductionOperation::ARG_IDX_MIN)
210 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100211 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000212 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
213 }
214 else
215 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100216 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000217 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
218 }
219
220 // Widen vectors
221 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
222 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
223 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
224 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
225 auto wide_u32_3 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
226 auto wide_u32_4 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
227 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
228 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
229 res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
230 res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
231 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
232 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
233 res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
234 res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);
235
236 uint32_t res = 0xFFFFFFFF;
237 int iter = 0;
238 do
239 {
240 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
241 pmin = wrapper::vpmin(pmin, pmin);
242 res = std::min(wrapper::vgetlane(pmin, 0), res);
243 iter++;
244 }
245 while(iter < 4);
246
247 return (res - 0xFFFFFFFF);
248}
Luca Foschianiee939fb2020-01-28 10:38:07 +0000249
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000250#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitasfad18382019-06-05 15:12:22 +0100251template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000252uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
253{
254 uint32x4x2_t mask{ 0 };
255 uint16x8_t mask_u16{ 0 };
256 if(op == ReductionOperation::ARG_IDX_MIN)
257 {
258 mask_u16 = wrapper::vcgt(b, a);
259 }
260 else
261 {
262 mask_u16 = wrapper::vclt(b, a);
263 }
264 mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
265 mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
266 uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
267 { idx + 4, idx + 5, idx + 6, idx + 7 }
268 }
269 };
270 if(axis != 0)
271 {
272 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
273 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
274 }
275 uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
276 wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
277 0, 0
278 };
279
280 return res;
281}
282
Usama Arifa4a08ad2019-05-20 12:38:33 +0100283// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
284inline float16x4_t calculate_min(float16x8_t in)
285{
286 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
287 pmin = wrapper::vpmin(pmin, pmin);
288 return wrapper::vpmin(pmin, pmin);
289}
290// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
291inline float16x4_t calculate_max(float16x8_t in)
292{
293 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
294 pmax = wrapper::vpmax(pmax, pmax);
295 return wrapper::vpmax(pmax, pmax);
296}
297
Usama Arif0a5a57a2019-05-23 14:20:33 +0100298template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000299uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
300{
301 uint32x4x2_t res_idx_mask{ 0 };
302 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
303 uint16x8_t mask_u16;
304 if(op == ReductionOperation::ARG_IDX_MIN)
305 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100306 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000307 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
308 }
309 else
310 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100311 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000312 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
313 }
314
315 // Widen vectors
316 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
317 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
318 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
319 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
320 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
321 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
322
323 uint32_t res = 0xFFFFFFFF;
Michalis Spyrouc89998f2021-08-26 14:11:44 +0100324 uint32_t iter = 0;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000325 do
326 {
327 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
328 pmin = wrapper::vpmin(pmin, pmin);
329 res = std::min(wrapper::vgetlane(pmin, 0), res);
330 iter++;
331 }
332 while(iter < 2);
333
334 return (res - 0xFFFFFFFF);
335}
336#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
337
Georgios Pinitasd9769582017-08-03 10:19:40 +0100338template <class F>
339class Reducer
340{
341public:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000342 static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100343 {
344 // Set out window
345 Window out_window(window);
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100346 out_window.set(Window::DimX, Window::Dimension(0, 1, 1));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100347
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100348 f(window, out_window, input, output, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100349 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000350 static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100351 {
352 // Set in window
353 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000354 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100355
356 in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000357 out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100358
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100359 f(in_window, out_window, input, output, 1, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100360 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000361 static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100362 {
363 // Set in window
364 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000365 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100366
367 in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000368 out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100369
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100370 f(in_window, out_window, input, output, 2, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100371 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000372 static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100373 {
374 // Set in/out window
375 Window in_window(window);
376 Window out_window(window);
377
378 in_window.set(3, Window::Dimension(0, 1, 1));
379 out_window.set(3, Window::Dimension(0, 1, 1));
380
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100381 f(in_window, out_window, input, output, 3, op);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100382 }
383};
384
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000385template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100386struct RedOpX
Georgios Pinitasd9769582017-08-03 10:19:40 +0100387{
Michele Di Giorgio33f41fa2021-03-09 14:09:08 +0000388 /** SIMD vector tag type. */
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100389 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
390
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100391 inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100392 {
Manuel Bottini6a5eee72021-04-30 12:37:04 +0100393 const size_t input_dim_0 = in->info()->dimension(0);
394 const int window_step_x = 16 / sizeof(T);
395 const auto window_start_x = static_cast<int>(in_window.x().start());
396 const auto window_end_x = static_cast<int>(in_window.x().end());
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100397
Georgios Pinitas412b7892020-11-11 21:05:24 +0000398 Window in_win_no_pad = in_window;
399 in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1));
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100400
Georgios Pinitas412b7892020-11-11 21:05:24 +0000401 Iterator input(in, in_win_no_pad);
402 Iterator output(out, out_window);
403
Omar Al Khatibe317baf2022-12-15 09:12:12 +0000404 execute_window_loop(
405 in_win_no_pad, [&](const Coordinates &)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000406 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100407 const auto input_ptr = reinterpret_cast<const T *>(input.ptr());
Georgios Pinitasd9769582017-08-03 10:19:40 +0100408
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100409 auto init_res_value = static_cast<T>(0.f);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100410 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000411 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100412 case ReductionOperation::ARG_IDX_MAX:
413 case ReductionOperation::ARG_IDX_MIN:
414 case ReductionOperation::MIN:
Usama Arif28f0dd92019-05-20 13:44:34 +0100415 case ReductionOperation::MAX:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100416 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100417 init_res_value = static_cast<T>(*input_ptr);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100418 break;
419 }
420 case ReductionOperation::PROD:
421 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100422 init_res_value = static_cast<T>(1.f);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100423 break;
424 }
425 default:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100426 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000427 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100428 auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000429 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000430
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100431 // Compute window_step_x elements per iteration
432 int x = window_start_x;
433 for(; x <= (window_end_x - window_step_x); x += window_step_x)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100434 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100435 const auto vec_elements = wrapper::vloadq(input_ptr + x);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000436 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100437 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000438 case ReductionOperation::SUM_SQUARE:
439 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
440 break;
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100441 case ReductionOperation::MEAN_SUM:
442 case ReductionOperation::SUM:
443 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
444 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000445 case ReductionOperation::PROD:
446 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
447 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000448 case ReductionOperation::ARG_IDX_MIN:
449 {
450 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100451 vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000452 vec_res_value = temp_vec_res_value;
453 break;
454 }
455 case ReductionOperation::ARG_IDX_MAX:
456 {
457 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100458 vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000459 vec_res_value = temp_vec_res_value;
460 break;
461 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100462 case ReductionOperation::MIN:
463 {
464 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
465 break;
466 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100467 case ReductionOperation::MAX:
468 {
469 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
470 break;
471 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000472 default:
473 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100474 }
475 }
476
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100477 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100478 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100479 case ReductionOperation::SUM:
480 case ReductionOperation::MEAN_SUM:
481 case ReductionOperation::SUM_SQUARE:
Michele Di Giorgio81d7e782019-08-16 18:03:35 +0100482 {
Manuel Bottini6a5eee72021-04-30 12:37:04 +0100483#ifdef ARM_COMPUTE_DEBUG_ENABLED
484 auto res = static_cast<T>(0.f);
485 for(int i = 0; i < S; ++i)
486 {
487 res += wrapper::vgetlane(vec_res_value, i);
488 }
489#else // ARM_COMPUTE_DEBUG_ENABLED
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100490 auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
491 for(int i = 0; i < S / 4; ++i)
492 {
493 carry_res = wrapper::vpadd(carry_res, carry_res);
494 }
495 auto res = wrapper::vgetlane(carry_res, 0);
Manuel Bottini6a5eee72021-04-30 12:37:04 +0100496#endif // ARM_COMPUTE_DEBUG_ENABLED
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100497 if(op == ReductionOperation::SUM_SQUARE)
498 {
499 // Compute left-over elements
500 for(; x < window_end_x; ++x)
501 {
502 res += (*(input_ptr + x)) * (*(input_ptr + x));
503 }
504 }
505 else
506 {
507 // Compute left-over elements
508 for(; x < window_end_x; ++x)
509 {
510 res += *(input_ptr + x);
511 }
512 }
513
514 if(op == ReductionOperation::MEAN_SUM)
515 {
Manuel Bottini6a5eee72021-04-30 12:37:04 +0100516 res /= input_dim_0;
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100517 }
518
519 *(reinterpret_cast<T *>(output.ptr())) = res;
520 break;
Michele Di Giorgio81d7e782019-08-16 18:03:35 +0100521 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100522 case ReductionOperation::PROD:
giuros01154bc1c2019-03-26 17:44:40 +0000523 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100524 auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
525 T res = 1;
526 for(int i = 0; i < S / 2; ++i)
527 {
528 res *= wrapper::vgetlane(carry_res, i);
529 }
giuros01154bc1c2019-03-26 17:44:40 +0000530
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100531 // Compute left-over elements
532 for(; x < window_end_x; ++x)
533 {
534 res *= *(input_ptr + x);
535 }
536
537 *(reinterpret_cast<T *>(output.ptr())) = res;
538 break;
539 }
540 case ReductionOperation::ARG_IDX_MIN:
giuros01154bc1c2019-03-26 17:44:40 +0000541 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100542 auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
543 auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
544
545 // Compute left-over elements
546 for(; x < window_end_x; ++x)
547 {
548 if(*(input_ptr + x) < res)
549 {
550 idx = x;
551 res = *(input_ptr + x);
552 }
553 }
554 *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
555 break;
giuros01154bc1c2019-03-26 17:44:40 +0000556 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100557 case ReductionOperation::ARG_IDX_MAX:
558 {
559 auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
560 auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
561
562 // Compute left-over elements
563 for(; x < window_end_x; ++x)
564 {
565 if(*(input_ptr + x) > res)
566 {
567 idx = x;
568 res = *(input_ptr + x);
569 }
570 }
571 *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
572 break;
573 }
574 case ReductionOperation::MIN:
575 {
576 auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
577
578 // Compute left-over elements
579 for(; x < window_end_x; ++x)
580 {
581 res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
582 }
583 *(reinterpret_cast<T *>(output.ptr())) = res;
584 break;
585 }
586 case ReductionOperation::MAX:
587 {
588 auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
589
590 // Compute left-over elements
591 for(; x < window_end_x; ++x)
592 {
593 res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
594 }
595 *(reinterpret_cast<T *>(output.ptr())) = res;
596 break;
597 }
598 default:
599 ARM_COMPUTE_ERROR("Not supported");
giuros01154bc1c2019-03-26 17:44:40 +0000600 }
giuros01154bc1c2019-03-26 17:44:40 +0000601 },
602 input, output);
603 }
604};
605
Luca Foschianiee939fb2020-01-28 10:38:07 +0000606template <typename T>
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100607struct RedOpX_quantized
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100608{
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100609 inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100610 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000611 using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
612
Omar Al Khatibe317baf2022-12-15 09:12:12 +0000613 const auto oq_info = out->info()->quantization_info().uniform();
614
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100615 const TensorInfo in_info = *(in->info());
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100616 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
617
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100618 const int window_step_x = 16 / sizeof(T);
619 const auto window_start_x = static_cast<int>(in_window.x().start());
620 const auto window_end_x = static_cast<int>(in_window.x().end());
621
Georgios Pinitas412b7892020-11-11 21:05:24 +0000622 Window in_win_no_pad = in_window;
623 in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1));
624
625 Iterator input(in, in_win_no_pad);
626 Iterator output(out, out_window);
627
Mohammed Suhail Munshi470cc5d2023-02-09 11:52:06 +0000628 const auto in_offset = static_cast<float>(iq_info.offset);
Omar Al Khatibe317baf2022-12-15 09:12:12 +0000629 const float in_scale = iq_info.scale;
630
Mohammed Suhail Munshi470cc5d2023-02-09 11:52:06 +0000631 const auto out_offset = static_cast<float>(oq_info.offset);
Omar Al Khatibe317baf2022-12-15 09:12:12 +0000632 const float out_scale = oq_info.scale;
633
Mohammed Suhail Munshi470cc5d2023-02-09 11:52:06 +0000634 const auto num_elements = static_cast<float>(in_info.dimension(0));
Omar Al Khatibe317baf2022-12-15 09:12:12 +0000635
636 const float A = in_scale / (out_scale * num_elements);
637 const float B = out_offset - (in_scale * in_offset) / (out_scale);
638
639 execute_window_loop(
640 in_win_no_pad, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100641 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100642 const auto input_ptr = reinterpret_cast<T *>(input.ptr());
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000643
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100644 auto vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
645 auto vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
646 auto vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
647 auto vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000648
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100649 auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
650 auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
651 auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
652 auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000653
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100654 typename wrapper::traits::neon_vector<T, 16>::type vec_res_value = { 0 };
655
656 if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100657 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100658 vec_res_value = wrapper::vdup_n(*input_ptr, wrapper::traits::vector_128_tag{});
659 }
660
661 uint32x4x4_t vec_res_idx{ { 0 } };
662 // Compute window_step_x elements per iteration
663 int x = window_start_x;
664 for(; x <= (window_end_x - window_step_x); x += window_step_x)
665 {
Pablo Marquez Telloc4c595a2021-05-04 17:23:09 +0100666 const auto vec_elements = wrapper::vloadq(input_ptr + x);
667 switch(op)
668 {
669 case ReductionOperation::SUM:
670 case ReductionOperation::MEAN_SUM:
671 {
672 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
673 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100674
Pablo Marquez Telloc4c595a2021-05-04 17:23:09 +0100675 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
676 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
677 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
678 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
679
680 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
681 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
682 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
683 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
684 break;
685 }
686 case ReductionOperation::PROD:
687 {
688 const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
689 const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale);
690
691 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
692 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
693
694 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
695 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
696 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
697 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
698
699 auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
700 auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
701 auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
702 auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
703
704 //de-quantize vec_elements
705 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
706 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
707 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
708 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
709
710 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
711 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
712 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
713 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
714 break;
715 }
716 case ReductionOperation::ARG_IDX_MIN:
717 {
718 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
719 vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
720 vec_res_value = temp_vec_res_value;
721 break;
722 }
723 case ReductionOperation::ARG_IDX_MAX:
724 {
725 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
726 vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
727 vec_res_value = temp_vec_res_value;
728 break;
729 }
730 case ReductionOperation::MIN:
731 {
732 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
733 break;
734 }
735 case ReductionOperation::MAX:
736 {
737 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
738 break;
739 }
740 default:
741 ARM_COMPUTE_ERROR("Not supported");
742 }
743 }
744
745 switch(op)
746 {
747 case ReductionOperation::ARG_IDX_MIN:
748 {
749 auto idx = calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
750 auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
751
752 // Compute left-over elements
753 for(; x < window_end_x; ++x)
754 {
755 if(*(input_ptr + x) < res)
756 {
757 idx = x;
758 res = *(input_ptr + x);
759 }
760 }
761 *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
762 break;
763 }
764 case ReductionOperation::ARG_IDX_MAX:
765 {
766 auto idx = calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
767 auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
768
769 // Compute left-over elements
770 for(; x < window_end_x; ++x)
771 {
772 if(*(input_ptr + x) > res)
773 {
774 idx = x;
775 res = *(input_ptr + x);
776 }
777 }
778 *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
779 break;
780 }
781 case ReductionOperation::MIN:
782 {
783 auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
784
785 // Compute left-over elements
786 for(; x < window_end_x; ++x)
787 {
788 res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
789 }
790 *(reinterpret_cast<T *>(output.ptr())) = res;
791 break;
792 }
793 case ReductionOperation::MAX:
794 {
795 auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
796
797 // Compute left-over elements
798 for(; x < window_end_x; ++x)
799 {
800 res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
801 }
802 *(reinterpret_cast<T *>(output.ptr())) = res;
803 break;
804 }
805 case ReductionOperation::PROD:
806 {
807 auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
808 carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
809 carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
810
811 float res = wrapper::vgetlane(carry_res, 0);
812 res *= wrapper::vgetlane(carry_res, 1);
813 res *= wrapper::vgetlane(carry_res, 2);
814 res *= wrapper::vgetlane(carry_res, 3);
815
816 // Compute left-over elements
817 for(; x < window_end_x; ++x)
818 {
819 //de-quantize input
820 if(std::is_same<T, uint8_t>::value)
821 {
822 res *= dequantize_qasymm8(*(input_ptr + x), iq_info);
823 }
824 else
825 {
826 res *= dequantize_qasymm8_signed(*(input_ptr + x), iq_info);
827 }
828 }
829
830 //re-quantize result
831 if(std::is_same<T, uint8_t>::value)
832 {
833 res = quantize_qasymm8(res, iq_info);
834 }
835 else
836 {
837 res = quantize_qasymm8_signed(res, iq_info);
838 }
839
840 *reinterpret_cast<T *>(output.ptr()) = static_cast<T>(res);
841 break;
842 }
843 case ReductionOperation::SUM:
844 case ReductionOperation::MEAN_SUM:
845 {
846 auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
847 carry_res = wrapper::vadd(carry_res, vec_res_value3);
848 carry_res = wrapper::vadd(carry_res, vec_res_value4);
849
850 auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
851 carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
852 auto res = static_cast<int32_t>(wrapper::vgetlane(carry_paddition, 0));
853
854 // Compute left-over elements
855 for(; x < window_end_x; ++x)
856 {
857 res += *(input_ptr + x);
858 }
859
860 if(op == ReductionOperation::MEAN_SUM)
861 {
Omar Al Khatibe317baf2022-12-15 09:12:12 +0000862 const int32_t resFinal = A * (static_cast<float>(res)) + B;
863
864 *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(resFinal);
Pablo Marquez Telloc4c595a2021-05-04 17:23:09 +0100865 }
866 else
867 {
868 // Subtract accumulated offsets
869 res -= (in_info.dimension(0) - 1) * iq_info.offset;
Omar Al Khatibe317baf2022-12-15 09:12:12 +0000870 *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(res);
Pablo Marquez Telloc4c595a2021-05-04 17:23:09 +0100871 }
Omar Al Khatibe317baf2022-12-15 09:12:12 +0000872
Pablo Marquez Telloc4c595a2021-05-04 17:23:09 +0100873 break;
874 }
875 default:
876 ARM_COMPUTE_ERROR("Not supported");
877 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100878 },
879 input, output);
880 }
881};
882
883template <typename T, int S>
884struct RedOpYZW
885{
Michele Di Giorgio33f41fa2021-03-09 14:09:08 +0000886 /** SIMD vector tag type. */
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100887 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
888 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
889
890 inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int axis, const ReductionOperation op)
891 {
Georgios Pinitas412b7892020-11-11 21:05:24 +0000892 const TensorInfo in_info = *(in->info());
893 const int window_step_x = 16 / sizeof(T);
894 const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
895 const auto window_end_x_tmp = static_cast<int>(in_window.x().end());
896 // As it split over x-axis, need to set the correct spiltted window start and end.
897 const auto window_start_x = static_cast<int>(0);
898 const auto window_end_x = static_cast<int>(in_window.shape().x());
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100899
Georgios Pinitas412b7892020-11-11 21:05:24 +0000900 Window in_win_no_pad = in_window;
901 in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
902 Window out_win_no_pad = out_window;
903 out_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100904
Georgios Pinitas412b7892020-11-11 21:05:24 +0000905 Iterator input(in, in_win_no_pad);
906 Iterator output(out, out_win_no_pad);
907
Omar Al Khatibe317baf2022-12-15 09:12:12 +0000908 execute_window_loop(
909 in_win_no_pad, [&](const Coordinates &)
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100910 {
911 const auto input_ptr = reinterpret_cast<T *>(input.ptr());
912
913 // Compute window_step_x elements per iteration
914 int x = window_start_x;
915 for(; x <= (window_end_x - window_step_x); x += window_step_x)
916 {
917 neon_vector vec_res_value = { 0 };
918 switch(op)
919 {
920 case ReductionOperation::ARG_IDX_MAX:
921 case ReductionOperation::ARG_IDX_MIN:
922 case ReductionOperation::MIN:
923 case ReductionOperation::MAX:
924 {
925 vec_res_value = wrapper::vloadq(input_ptr + x);
926 break;
927 }
928 case ReductionOperation::PROD:
929 {
930 vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
931 break;
932 }
933 default:
934 {
935 vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
936 break;
937 }
938 }
939 uint32x4x4_t vec_res_idx{ { 0 } };
940
941 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
942 {
943 const T *in_ptr = reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);
944 const auto vec_elements = wrapper::vloadq(in_ptr);
945 switch(op)
946 {
947 case ReductionOperation::SUM:
948 case ReductionOperation::MEAN_SUM:
949 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
950 break;
951 case ReductionOperation::SUM_SQUARE:
952 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
953 break;
954 case ReductionOperation::PROD:
955 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
956 break;
957 case ReductionOperation::ARG_IDX_MIN:
958 {
959 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
960 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
961 vec_res_value = temp_vec_res_value;
962 break;
963 }
964 case ReductionOperation::ARG_IDX_MAX:
965 {
966 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
967 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
968 vec_res_value = temp_vec_res_value;
969 break;
970 }
971 case ReductionOperation::MIN:
972 {
973 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
974 break;
975 }
976 case ReductionOperation::MAX:
977 {
978 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
979 break;
980 }
981 default:
982 ARM_COMPUTE_ERROR("Not supported");
983 }
984 }
985
986 if(op == ReductionOperation::MEAN_SUM)
987 {
988 auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
989 vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
990 }
991
992 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
993 {
994 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]);
995#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
996 if(std::is_same<T, float16_t>::value)
997 {
998 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, vec_res_idx.val[1]);
999 }
1000#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001001 }
1002 else
1003 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001004 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value);
Michalis Spyrou19bd4122020-01-22 10:27:06 +00001005 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001006 }
1007
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001008 // Compute left-over elements
1009 for(; x < window_end_x; ++x)
1010 {
1011 auto res_value = 0.f;
1012 switch(op)
1013 {
1014 case ReductionOperation::ARG_IDX_MAX:
1015 case ReductionOperation::ARG_IDX_MIN:
1016 case ReductionOperation::MIN:
1017 case ReductionOperation::MAX:
1018 {
1019 res_value = *(input_ptr + x);
1020 break;
1021 }
1022 case ReductionOperation::PROD:
1023 {
1024 res_value = static_cast<T>(1.f);
1025 break;
1026 }
1027 default:
1028 {
1029 res_value = static_cast<T>(0.f);
1030 break;
1031 }
1032 }
1033
1034 uint32_t res_idx = 0;
1035 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1036 {
1037 const T *in_ptr = reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);
1038
1039 switch(op)
1040 {
1041 case ReductionOperation::SUM:
1042 case ReductionOperation::MEAN_SUM:
1043 res_value += *in_ptr;
1044 break;
1045 case ReductionOperation::SUM_SQUARE:
1046 res_value += *in_ptr * *in_ptr;
1047 break;
1048 case ReductionOperation::PROD:
1049 res_value *= *in_ptr;
1050 break;
1051 case ReductionOperation::ARG_IDX_MIN:
1052 {
1053 if(*in_ptr < res_value)
1054 {
1055 res_value = *in_ptr;
1056 res_idx = dim;
1057 }
1058 break;
1059 }
1060 case ReductionOperation::ARG_IDX_MAX:
1061 {
1062 if(*in_ptr > res_value)
1063 {
1064 res_value = *in_ptr;
1065 res_idx = dim;
1066 }
1067 break;
1068 }
1069 case ReductionOperation::MIN:
1070 {
1071 res_value = *in_ptr < res_value ? *in_ptr : res_value;
1072 break;
1073 }
1074 case ReductionOperation::MAX:
1075 {
1076 res_value = *in_ptr > res_value ? *in_ptr : res_value;
1077 break;
1078 }
1079 default:
1080 ARM_COMPUTE_ERROR("Not supported");
1081 }
1082 }
1083
1084 if(op == ReductionOperation::MEAN_SUM)
1085 {
1086 res_value /= in_info.dimension(axis);
1087 }
1088
1089 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
1090 {
1091 *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx;
1092 }
1093 else
1094 {
1095 *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value;
1096 }
1097 }
1098 },
1099 input, output);
1100 }
1101};
1102
1103template <typename T, int S, int axis, ReductionOperation op>
1104struct RedOpYZW_complex
1105{
Michele Di Giorgio33f41fa2021-03-09 14:09:08 +00001106 /** SIMD vector tag type. */
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001107 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
1108 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
1109
1110 inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation)
1111 {
1112 ARM_COMPUTE_ERROR_ON(axis != 2);
Georgios Pinitas412b7892020-11-11 21:05:24 +00001113 ARM_COMPUTE_ERROR_ON(op != ReductionOperation::SUM);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001114
Georgios Pinitas412b7892020-11-11 21:05:24 +00001115 const TensorInfo in_info = *(in->info());
1116 const size_t stride_z = in_info.strides_in_bytes()[axis];
1117 const int window_step_x = 16 / sizeof(T);
1118 const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
1119 const auto window_end_x_tmp = static_cast<int>(in_window.x().end());
1120 // As it split over x-axis, need to set the correct spiltted window start and end.
1121 const auto window_start_x = static_cast<int>(0);
1122 const auto window_end_x = static_cast<int>(in_window.shape().x());
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001123
Georgios Pinitas412b7892020-11-11 21:05:24 +00001124 Window in_win_no_pad = in_window;
1125 in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
1126 Window out_win_no_pad = out_window;
1127 out_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001128
Georgios Pinitas412b7892020-11-11 21:05:24 +00001129 Iterator input(in, in_win_no_pad);
1130 Iterator output(out, out_win_no_pad);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001131
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001132 execute_window_loop(
1133 in_win_no_pad, [&](const Coordinates &)
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001134 {
1135 // Compute window_step_x elements per iteration
1136 int x = window_start_x;
1137 for(; x <= (window_end_x - window_step_x); x += window_step_x)
1138 {
1139 neon_vector vec_res_value_0 = { 0 };
1140 neon_vector vec_res_value_1 = { 0 };
1141
1142 vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
1143 vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
1144
1145 T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
1146 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1147 {
Georgios Pinitas412b7892020-11-11 21:05:24 +00001148 T *in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
1149 T *in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim);
1150
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001151 const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
1152 const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);
1153
Georgios Pinitas412b7892020-11-11 21:05:24 +00001154 vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
1155 vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001156 }
1157
1158 wrapper::vstore(out_ptr, vec_res_value_0);
1159 wrapper::vstore(out_ptr + 4, vec_res_value_1);
1160 }
1161
1162 // Compute left-over elements
1163 for(; x < window_end_x; ++x)
1164 {
1165 auto res_value_0 = 0.f;
1166 auto res_value_1 = 0.f;
1167
1168 T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
1169 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1170 {
Georgios Pinitas412b7892020-11-11 21:05:24 +00001171 T *in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
1172 res_value_0 += *in_ptr;
1173 res_value_1 += *(in_ptr + 1);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001174 }
1175 *out_ptr = res_value_0;
1176 *(out_ptr + 1) = res_value_1;
1177 }
1178 },
1179 input, output);
1180 }
1181};
1182
1183template <typename T>
1184struct RedOpYZW_quantized
1185{
1186 inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int axis, const ReductionOperation op)
1187 {
Georgios Pinitas412b7892020-11-11 21:05:24 +00001188 const TensorInfo in_info = *(in->info());
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001189 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
Georgios Pinitas412b7892020-11-11 21:05:24 +00001190 using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001191
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001192 const auto oq_info = out->info()->quantization_info().uniform();
1193
Georgios Pinitas412b7892020-11-11 21:05:24 +00001194 const int window_step_x = 16 / sizeof(T);
1195 const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
1196 const auto window_end_x_tmp = static_cast<int>(in_window.x().end());
1197 // As it split over x-axis, need to set the correct spiltted window start and end.
1198 const auto window_start_x = static_cast<int>(0);
1199 const auto window_end_x = static_cast<int>(in_window.shape().x());
1200
1201 Window in_win_no_pad = in_window;
1202 in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
1203 Window out_win_no_pad = out_window;
1204 out_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));
1205
1206 Iterator input(in, in_win_no_pad);
1207 Iterator output(out, out_win_no_pad);
1208
Pablo Marquez Telloc4c595a2021-05-04 17:23:09 +01001209 using vector_type = typename wrapper::traits::neon_bitvector<PromotedType, wrapper::traits::BitWidth::W128>::type;
1210 using vector_type_f = typename wrapper::traits::neon_vector<float, 4>::type;
1211
1212 vector_type vec_res_value1{};
1213 vector_type vec_res_value2{};
1214 vector_type vec_res_value3{};
1215 vector_type vec_res_value4{};
1216
1217 vector_type_f vec_res_value1_f{};
1218 vector_type_f vec_res_value2_f{};
1219 vector_type_f vec_res_value3_f{};
1220 vector_type_f vec_res_value4_f{};
1221
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001222 const float in_offset = static_cast<float>(iq_info.offset);
1223 const float in_scale = iq_info.scale;
1224
1225 const float out_offset = static_cast<float>(oq_info.offset);
1226 const float out_scale = oq_info.scale;
1227
1228 const float num_elements = static_cast<float>(in_info.dimension(axis));
1229
1230 const float A = in_scale / (out_scale * num_elements);
1231 const float B = out_offset - (in_scale * in_offset) / (out_scale);
1232
1233 const auto vec_A = wrapper::vdup_n(static_cast<float>(A), wrapper::traits::vector_128_tag{});
1234 const auto vec_B = wrapper::vdup_n(static_cast<float>(B), wrapper::traits::vector_128_tag{});
1235
1236 execute_window_loop(
1237 in_win_no_pad, [&](const Coordinates &)
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001238 {
1239 const auto input_ptr = reinterpret_cast<T *>(input.ptr());
1240
1241 // Compute window_step_x elements per iteration
1242 int x = window_start_x;
1243 for(; x <= (window_end_x - window_step_x); x += window_step_x)
1244 {
1245 uint32x4x4_t vec_res_idx{ { 0 } };
Pablo Marquez Telloc4c595a2021-05-04 17:23:09 +01001246 vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1247 vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1248 vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1249 vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001250
Pablo Marquez Telloc4c595a2021-05-04 17:23:09 +01001251 vec_res_value1_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1252 vec_res_value2_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1253 vec_res_value3_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1254 vec_res_value4_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001255
1256 auto vec_res_value = wrapper::vloadq(input_ptr + x);
1257
1258 for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
1259 {
1260 const T *in_ptr = input_ptr + x + in_info.strides_in_bytes()[axis] * index_dim;
1261 const auto vec_elements = wrapper::vloadq(in_ptr);
1262 switch(op)
1263 {
1264 case ReductionOperation::SUM:
1265 case ReductionOperation::MEAN_SUM:
1266 {
1267 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
1268 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
1269
1270 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
1271 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
1272 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
1273 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
1274
1275 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
1276 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
1277 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
1278 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
1279 break;
1280 }
1281 case ReductionOperation::PROD:
1282 {
1283 const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
1284 const auto scale32x4f_4 = wrapper::vdup_n(iq_info.scale, wrapper::traits::vector_128_tag{});
1285
1286 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
1287 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
1288
1289 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
1290 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
1291 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
1292 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
1293
1294 auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
1295 auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
1296 auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
1297 auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
1298
1299 //de-quantize vec_elements
1300 temp32x4f_1 = wrapper::vmul(wrapper::vsub(temp32x4f_1, offset32x4f_4), scale32x4f_4);
1301 temp32x4f_2 = wrapper::vmul(wrapper::vsub(temp32x4f_2, offset32x4f_4), scale32x4f_4);
1302 temp32x4f_3 = wrapper::vmul(wrapper::vsub(temp32x4f_3, offset32x4f_4), scale32x4f_4);
1303 temp32x4f_4 = wrapper::vmul(wrapper::vsub(temp32x4f_4, offset32x4f_4), scale32x4f_4);
1304
1305 vec_res_value1_f = wrapper::vmul(temp32x4f_1, vec_res_value1_f);
1306 vec_res_value2_f = wrapper::vmul(temp32x4f_2, vec_res_value2_f);
1307 vec_res_value3_f = wrapper::vmul(temp32x4f_3, vec_res_value3_f);
1308 vec_res_value4_f = wrapper::vmul(temp32x4f_4, vec_res_value4_f);
1309 break;
1310 }
1311 case ReductionOperation::ARG_IDX_MIN:
1312 {
1313 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1314 vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
1315 vec_res_value = temp_vec_res_value;
1316 break;
1317 }
1318 case ReductionOperation::ARG_IDX_MAX:
1319 {
1320 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1321 vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
1322 vec_res_value = temp_vec_res_value;
1323 break;
1324 }
1325 case ReductionOperation::MIN:
1326 {
1327 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1328 break;
1329 }
1330 case ReductionOperation::MAX:
1331 {
1332 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1333 break;
1334 }
1335 default:
1336 ARM_COMPUTE_ERROR("Not supported");
1337 }
1338 }
1339
1340 switch(op)
1341 {
1342 case ReductionOperation::ARG_IDX_MIN:
1343 case ReductionOperation::ARG_IDX_MAX:
1344 {
1345 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x), vec_res_idx.val[0]);
1346 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 4, vec_res_idx.val[1]);
1347 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 8, vec_res_idx.val[2]);
1348 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 12, vec_res_idx.val[3]);
1349 break;
1350 }
1351 case ReductionOperation::MIN:
1352 case ReductionOperation::MAX:
1353 {
1354 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), vec_res_value);
1355 break;
1356 }
1357 case ReductionOperation::SUM:
1358 {
1359 // Subtract offsets
1360 auto offsets = vdupq_n_s32((in_info.dimension(axis) - 1) * iq_info.offset);
1361
1362 auto vec_res_s_value1 = wrapper::vreinterpret(vec_res_value1);
1363 auto vec_res_s_value2 = wrapper::vreinterpret(vec_res_value2);
1364 auto vec_res_s_value3 = wrapper::vreinterpret(vec_res_value3);
1365 auto vec_res_s_value4 = wrapper::vreinterpret(vec_res_value4);
1366
1367 vec_res_s_value1 = wrapper::vsub(vec_res_s_value1, offsets);
1368 vec_res_s_value2 = wrapper::vsub(vec_res_s_value2, offsets);
1369 vec_res_s_value3 = wrapper::vsub(vec_res_s_value3, offsets);
1370 vec_res_s_value4 = wrapper::vsub(vec_res_s_value4, offsets);
1371
1372 const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_s_value1), wrapper::vqmovn(vec_res_s_value2));
1373 const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_s_value3), wrapper::vqmovn(vec_res_s_value4));
1374
1375 combine_and_store<T>(temp16x8t_1, temp16x8t_2, output, x);
1376 break;
1377 }
1378 case ReductionOperation::MEAN_SUM:
1379 {
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001380 vec_res_value1_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value1), vec_A);
1381 vec_res_value2_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value2), vec_A);
1382 vec_res_value3_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value3), vec_A);
1383 vec_res_value4_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value4), vec_A);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001384
Mohammed Suhail Munshi470cc5d2023-02-09 11:52:06 +00001385#ifdef __aarch64__
1386 vec_res_value1 = wrapper::vcvta<PromotedType>(vec_res_value1_f);
1387 vec_res_value2 = wrapper::vcvta<PromotedType>(vec_res_value2_f);
1388 vec_res_value3 = wrapper::vcvta<PromotedType>(vec_res_value3_f);
1389 vec_res_value4 = wrapper::vcvta<PromotedType>(vec_res_value4_f);
1390#else // defined(__aarch64__)
1391 vec_res_value1 = wrapper::vcvt<PromotedType>(vec_res_value1_f);
1392 vec_res_value2 = wrapper::vcvt<PromotedType>(vec_res_value2_f);
1393 vec_res_value3 = wrapper::vcvt<PromotedType>(vec_res_value3_f);
1394 vec_res_value4 = wrapper::vcvt<PromotedType>(vec_res_value4_f);
1395#endif // __aarch64__
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001396
1397 const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1398 const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1399 auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
1400
1401 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res);
1402 break;
1403 }
1404 case ReductionOperation::PROD:
1405 {
1406 const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
1407 const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale));
1408
1409 //re-quantize
1410 vec_res_value1_f = wrapper::vadd(wrapper::vmul(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
1411 vec_res_value2_f = wrapper::vadd(wrapper::vmul(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
1412 vec_res_value3_f = wrapper::vadd(wrapper::vmul(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
1413 vec_res_value4_f = wrapper::vadd(wrapper::vmul(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
1414
1415 vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f);
1416 vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f);
1417 vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f);
1418 vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f);
1419
1420 const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1421 const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1422 auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
1423
1424 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res);
1425 break;
1426 }
1427 default:
1428 ARM_COMPUTE_ERROR("Not supported");
1429 }
1430 }
1431
1432 // Compute left-over elements
1433 for(; x < window_end_x; ++x)
1434 {
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001435 float res_value = 0.f;
1436 int32_t res_value_q = 0;
1437
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001438 switch(op)
1439 {
1440 case ReductionOperation::ARG_IDX_MAX:
1441 case ReductionOperation::ARG_IDX_MIN:
1442 case ReductionOperation::MIN:
1443 case ReductionOperation::MAX:
1444 {
1445 res_value = *(input_ptr + x);
1446 break;
1447 }
1448 case ReductionOperation::PROD:
1449 {
1450 res_value = static_cast<T>(1.0f);
1451 break;
1452 }
1453 default:
1454 {
1455 res_value = static_cast<T>(0.0f);
1456 break;
1457 }
1458 }
1459 uint32_t res_idx = 0;
1460
1461 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1462 {
1463 const T *in_ptr = reinterpret_cast<T *>(input.ptr() + x + in_info.strides_in_bytes()[axis] * dim);
1464 switch(op)
1465 {
1466 case ReductionOperation::SUM:
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001467 {
1468 res_value += *in_ptr;
1469 break;
1470 }
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001471 case ReductionOperation::MEAN_SUM:
1472 {
1473 res_value_q += *in_ptr;
1474 break;
1475 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001476 case ReductionOperation::SUM_SQUARE:
1477 {
1478 res_value += *in_ptr * *in_ptr;
1479 break;
1480 }
1481 case ReductionOperation::PROD:
1482 {
1483 //de-quantize input
1484 if(std::is_same<T, uint8_t>::value)
1485 {
Michalis Spyrou272e4252020-10-06 17:44:40 +01001486 res_value *= dequantize_qasymm8(*in_ptr, iq_info);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001487 }
1488 else
1489 {
Michalis Spyrou272e4252020-10-06 17:44:40 +01001490 res_value *= dequantize_qasymm8_signed(*in_ptr, iq_info);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001491 }
1492 break;
1493 }
1494 case ReductionOperation::ARG_IDX_MIN:
1495 {
1496 if(*in_ptr < res_value)
1497 {
1498 res_value = *in_ptr;
1499 res_idx = dim;
1500 }
1501 break;
1502 }
1503 case ReductionOperation::ARG_IDX_MAX:
1504 {
1505 if(*in_ptr > res_value)
1506 {
1507 res_value = *in_ptr;
1508 res_idx = dim;
1509 }
1510 break;
1511 }
1512 case ReductionOperation::MIN:
1513 {
1514 res_value = *in_ptr < res_value ? *in_ptr : res_value;
1515 break;
1516 }
1517 case ReductionOperation::MAX:
1518 {
1519 res_value = *in_ptr > res_value ? *in_ptr : res_value;
1520 break;
1521 }
1522 default:
1523 ARM_COMPUTE_ERROR("Not supported");
1524 }
1525 }
1526
1527 switch(op)
1528 {
1529 case ReductionOperation::MEAN_SUM:
1530 {
Mohammed Suhail Munshi470cc5d2023-02-09 11:52:06 +00001531 // Apply previously calculated coefficients (with rounding on aarch64)
1532#ifdef __aarch64__
1533 const int32_t res = arm_compute::support::cpp11::round(A * (static_cast<float>(res_value_q)) + B);
1534#else // defined(__aarch64__)
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001535 const int32_t res = A * (static_cast<float>(res_value_q)) + B;
Mohammed Suhail Munshi470cc5d2023-02-09 11:52:06 +00001536#endif // __aarch64__
Michalis Spyrou272e4252020-10-06 17:44:40 +01001537 *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001538 break;
1539 }
1540 case ReductionOperation::SUM:
1541 {
1542 // Subtract accumulated offsets
1543 res_value -= (in_info.dimension(axis) - 1) * iq_info.offset;
1544 *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res_value);
1545 break;
1546 }
1547 case ReductionOperation::PROD:
1548 {
1549 //re-quantize result
Michalis Spyrou272e4252020-10-06 17:44:40 +01001550 T res = 0;
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001551 if(std::is_same<T, uint8_t>::value)
1552 {
Michalis Spyrou272e4252020-10-06 17:44:40 +01001553 res = quantize_qasymm8(res_value, iq_info);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001554 }
1555 else
1556 {
Michalis Spyrou272e4252020-10-06 17:44:40 +01001557 res = quantize_qasymm8_signed(res_value, iq_info);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001558 }
Michalis Spyrou272e4252020-10-06 17:44:40 +01001559 *(reinterpret_cast<T *>(output.ptr() + x)) = res;
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001560 break;
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001561 }
1562 case ReductionOperation::ARG_IDX_MIN:
1563 case ReductionOperation::ARG_IDX_MAX:
1564 {
1565 *(reinterpret_cast<uint32_t *>(output.ptr() + x * 4)) = res_idx;
1566 break;
1567 }
1568 default:
1569 *(reinterpret_cast<T *>(output.ptr() + x)) = res_value;
1570 }
1571 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001572 },
1573 input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001574 }
1575};
1576
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001577void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001578{
giuros01154bc1c2019-03-26 17:44:40 +00001579 const bool is_complex = (input->info()->num_channels() == 2);
1580
1581 if(is_complex)
1582 {
1583 switch(axis)
1584 {
1585 case 2:
1586 switch(input->info()->data_type())
1587 {
1588 case DataType::F32:
1589 switch(op)
1590 {
1591 case ReductionOperation::SUM:
1592 return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op);
1593 default:
1594 ARM_COMPUTE_ERROR("Not supported");
1595 }
1596 default:
1597 ARM_COMPUTE_ERROR("Not supported");
1598 }
1599 default:
1600 ARM_COMPUTE_ERROR("Not supported");
1601 }
Manuel Bottini6a5eee72021-04-30 12:37:04 +01001602 return;
giuros01154bc1c2019-03-26 17:44:40 +00001603 }
1604
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001605 switch(axis)
1606 {
1607 case 0:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001608 {
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001609 switch(input->info()->data_type())
1610 {
1611 case DataType::QASYMM8:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001612 {
Luca Foschianiee939fb2020-01-28 10:38:07 +00001613 return Reducer<RedOpX_quantized<uint8_t>>::reduceX(window, input, output, RedOpX_quantized<uint8_t>(), op);
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001614 }
Luca Foschianiee939fb2020-01-28 10:38:07 +00001615 case DataType::QASYMM8_SIGNED:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001616 {
Luca Foschianiee939fb2020-01-28 10:38:07 +00001617 return Reducer<RedOpX_quantized<int8_t>>::reduceX(window, input, output, RedOpX_quantized<int8_t>(), op);
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001618 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001619#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1620 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001621 return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001622#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1623 case DataType::F32:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001624 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001625 return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001626 }
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001627 case DataType::S32:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001628 {
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001629 return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001630 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001631 default:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001632 {
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001633 ARM_COMPUTE_ERROR("Not supported");
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001634 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001635 }
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001636 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001637 case 1:
1638 switch(input->info()->data_type())
1639 {
1640 case DataType::QASYMM8:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001641 {
Luca Foschianiee939fb2020-01-28 10:38:07 +00001642 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceY(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001643 }
Luca Foschianiee939fb2020-01-28 10:38:07 +00001644 case DataType::QASYMM8_SIGNED:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001645 {
Luca Foschianiee939fb2020-01-28 10:38:07 +00001646 return Reducer<RedOpYZW_quantized<int8_t>>::reduceY(window, input, output, RedOpYZW_quantized<int8_t>(), op);
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001647 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001648#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1649 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001650 return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001651#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1652 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001653 return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001654 case DataType::S32:
1655 return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001656 default:
1657 ARM_COMPUTE_ERROR("Not supported");
1658 }
1659 case 2:
1660 switch(input->info()->data_type())
1661 {
1662 case DataType::QASYMM8:
Luca Foschianiee939fb2020-01-28 10:38:07 +00001663 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
1664 case DataType::QASYMM8_SIGNED:
1665 return Reducer<RedOpYZW_quantized<int8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001666#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1667 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001668 return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001669#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1670 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001671 return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001672 case DataType::S32:
1673 return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001674 default:
1675 ARM_COMPUTE_ERROR("Not supported");
1676 }
1677 case 3:
1678 switch(input->info()->data_type())
1679 {
1680 case DataType::QASYMM8:
Luca Foschianiee939fb2020-01-28 10:38:07 +00001681 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceW(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
1682 case DataType::QASYMM8_SIGNED:
1683 return Reducer<RedOpYZW_quantized<int8_t>>::reduceW(window, input, output, RedOpYZW_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001684#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1685 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001686 return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001687#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1688 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001689 return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001690 case DataType::S32:
1691 return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001692 default:
1693 ARM_COMPUTE_ERROR("Not supported");
1694 }
1695 default:
1696 ARM_COMPUTE_ERROR("Unsupported reduction axis");
1697 }
1698}
John Richardson73d4aef2018-05-08 14:34:33 +01001699
1700Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1701{
1702 ARM_COMPUTE_UNUSED(op);
1703
1704 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
Georgios Pinitas8f5802f2019-02-22 11:08:32 +00001705 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
giuros01154bc1c2019-03-26 17:44:40 +00001706
1707 if(input->num_channels() == 1)
1708 {
Luca Foschianiee939fb2020-01-28 10:38:07 +00001709 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
giuros01154bc1c2019-03-26 17:44:40 +00001710 }
1711 else
1712 {
1713 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
1714 ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM);
1715 ARM_COMPUTE_RETURN_ERROR_ON(axis != 2);
1716 }
John Richardson73d4aef2018-05-08 14:34:33 +01001717
1718 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001719 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
John Richardson73d4aef2018-05-08 14:34:33 +01001720
1721 if(output->total_size() != 0)
1722 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001723 bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
1724 if(!is_arg_min_max)
1725 {
1726 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
giuros01154bc1c2019-03-26 17:44:40 +00001727 ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels());
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001728 }
1729 else
1730 {
Michele Di Giorgio9637b2e2019-09-23 16:49:49 +01001731 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32, DataType::S32);
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001732 }
John Richardson73d4aef2018-05-08 14:34:33 +01001733
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001734 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001735 const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
1736 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
1737 }
1738
1739 return Status{};
1740}
Georgios Pinitasd9769582017-08-03 10:19:40 +01001741} // namespace
1742
1743NEReductionOperationKernel::NEReductionOperationKernel()
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001744 : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001745{
1746}
1747
Georgios Pinitasd9769582017-08-03 10:19:40 +01001748void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
1749{
1750 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001751
John Richardson73d4aef2018-05-08 14:34:33 +01001752 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001753
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001754 _input = input;
1755 _output = output;
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001756 _op = op;
1757 _reduction_axis = axis;
Georgios Pinitasd9769582017-08-03 10:19:40 +01001758
1759 // Configure kernel window
Georgios Pinitas412b7892020-11-11 21:05:24 +00001760 Window win = calculate_max_window(*input->info(), Steps());
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001761 INEKernel::configure(win);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001762
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001763 // Calculate output shape and set if empty
1764 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->info()->tensor_shape(), axis);
1765 // Output auto initialization if not yet initialized
1766 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
1767 DataType output_data_type = is_arg_min_max ? DataType::S32 : input->info()->data_type();
1768 auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
John Richardson73d4aef2018-05-08 14:34:33 +01001769}
1770
1771Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1772{
1773 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
John Richardson73d4aef2018-05-08 14:34:33 +01001774
1775 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +01001776}
1777
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001778void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001779{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001780 ARM_COMPUTE_UNUSED(info);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001781 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1782 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1783
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001784 reduce_op(window, _input, _output, _reduction_axis, _op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001785}
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001786} // namespace arm_compute