blob: 455d604b3b53fd59eb2143d74ebd50da75c02bf1 [file] [log] [blame]
Georgios Pinitasd9769582017-08-03 10:19:40 +01001/*
Mohammed Suhail Munshiecaa10a2023-02-09 11:52:06 +00002 * Copyright (c) 2017-2023 Arm Limited.
Georgios Pinitasd9769582017-08-03 10:19:40 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrouebcebf12020-10-21 00:04:14 +010024#include "src/core/NEON/kernels/NEReductionOperationKernel.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010025
26#include "arm_compute/core/Coordinates.h"
27#include "arm_compute/core/Helpers.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010028#include "arm_compute/core/ITensor.h"
John Richardson73d4aef2018-05-08 14:34:33 +010029#include "arm_compute/core/TensorInfo.h"
Luca Foschianiee939fb2020-01-28 10:38:07 +000030#include "arm_compute/core/Utils.h"
Michalis Spyrouaea14c62019-01-03 11:10:25 +000031#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010032#include "arm_compute/core/Validate.h"
33
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010034#include "src/core/CPP/Validate.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010035#include "src/core/helpers/AutoConfiguration.h"
36#include "src/core/helpers/WindowHelpers.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010037#include "src/core/NEON/INEKernel.h"
38#include "src/core/NEON/NEMath.h"
39#include "src/core/NEON/wrapper/wrapper.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010040#include "support/SaturateCast.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010041
42#include <arm_neon.h>
43
Michalis Spyroubcf8a962018-10-12 10:51:31 +010044namespace arm_compute
45{
Georgios Pinitasd9769582017-08-03 10:19:40 +010046namespace
47{
Luca Foschianiee939fb2020-01-28 10:38:07 +000048// Helper function that calls vqmovun/vqmvn, vcombine and vstore, allows templating of RedOpYZW_quantized
49template <typename T>
Sheri Zhang4d91dc62020-09-23 11:22:50 +010050void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0)
Luca Foschianiee939fb2020-01-28 10:38:07 +000051{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010052 if (std::is_same<T, uint8_t>::value)
Luca Foschianiee939fb2020-01-28 10:38:07 +000053 {
54 auto res = wrapper::vcombine(wrapper::vqmovun(t1), wrapper::vqmovun(t2));
Sheri Zhang4d91dc62020-09-23 11:22:50 +010055 wrapper::vstore(output.ptr() + offset, res);
Luca Foschianiee939fb2020-01-28 10:38:07 +000056 }
57 else
58 {
59 auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2));
Sheri Zhang4d91dc62020-09-23 11:22:50 +010060 wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), res);
Luca Foschianiee939fb2020-01-28 10:38:07 +000061 }
62}
63
Michalis Spyroub9626ab2019-05-13 17:41:01 +010064template <typename T>
65uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000066{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010067 uint32x4_t mask{0};
68 if (op == ReductionOperation::ARG_IDX_MIN)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000069 {
70 mask = wrapper::vcgt(b, a);
71 }
72 else
73 {
74 mask = wrapper::vclt(b, a);
75 }
76
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010077 uint32x4_t vec_idx = {idx, idx + 1, idx + 2, idx + 3};
78 if (axis != 0)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000079 {
80 vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
81 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010082 uint32x4x4_t res = {{wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0}};
Michalis Spyrouaea14c62019-01-03 11:10:25 +000083
84 return res;
85}
86
Luca Foschianiee939fb2020-01-28 10:38:07 +000087template <typename T>
88uint32x4x4_t calculate_index_quantized(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000089{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010090 uint32x4x4_t mask{{0}};
91 uint8x16_t mask_u8{0};
92 if (op == ReductionOperation::ARG_IDX_MIN)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000093 {
94 mask_u8 = wrapper::vcgt(b, a);
95 }
96 else
97 {
98 mask_u8 = wrapper::vclt(b, a);
99 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100100 auto wide_u16_1 =
101 wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
102 auto wide_u16_2 =
103 wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
104 mask.val[0] =
105 wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
106 mask.val[1] =
107 wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
108 mask.val[2] =
109 wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
110 mask.val[3] =
111 wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
Michalis Spyrou254a48a2019-01-14 17:27:39 +0000112
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100113 uint32x4x4_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3},
114 {idx + 4, idx + 5, idx + 6, idx + 7},
115 {idx + 8, idx + 9, idx + 10, idx + 11},
116 {idx + 12, idx + 13, idx + 14, idx + 15}}};
117 if (axis != 0)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000118 {
119 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
120 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
121 vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
122 vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
123 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100124 uint32x4x4_t res = {
125 {vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]), vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
126 vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]), vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])}};
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000127
128 return res;
129}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100130
131// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000132template <typename T>
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100133inline typename std::enable_if<
134 std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
135 typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type>::type
136calculate_min(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100137{
138 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
139 return wrapper::vpmin(pmin, pmin);
140}
141
Luca Foschianiee939fb2020-01-28 10:38:07 +0000142// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
143template <typename T>
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100144inline typename std::enable_if<
145 std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
146 typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type>::type
147calculate_min(T in)
Luca Foschianiee939fb2020-01-28 10:38:07 +0000148{
149 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
150 pmin = wrapper::vpmin(pmin, pmin);
151 pmin = wrapper::vpmin(pmin, pmin);
152 return wrapper::vpmin(pmin, pmin);
153}
154
Usama Arifa4a08ad2019-05-20 12:38:33 +0100155// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000156template <typename T>
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100157inline typename std::enable_if<
158 std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
159 typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type>::type
160calculate_max(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100161{
162 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
163 return wrapper::vpmax(pmax, pmax);
164}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100165
166// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000167template <typename T>
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100168inline typename std::enable_if<
169 std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
170 typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type>::type
171calculate_max(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100172{
173 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
Luca Foschianiee939fb2020-01-28 10:38:07 +0000174 pmax = wrapper::vpmax(pmax, pmax);
175 pmax = wrapper::vpmax(pmax, pmax);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100176 return wrapper::vpmax(pmax, pmax);
177}
178
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100179template <typename T>
180uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000181{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100182 uint32x4_t res_idx_mask{0};
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000183 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
184
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100185 if (op == ReductionOperation::ARG_IDX_MIN)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000186 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100187 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000188 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
189 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
190 }
191 else
192 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100193 auto pmax = calculate_max(vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100194 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000195 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
196 }
197
198 res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
199 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
200 pmin = wrapper::vpmin(pmin, pmin);
201 uint32_t res = wrapper::vgetlane(pmin, 0);
202
203 return (res - 0xFFFFFFFF);
204}
205
Luca Foschianiee939fb2020-01-28 10:38:07 +0000206template <typename T>
207uint32_t calculate_vector_index_quantized(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000208{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100209 uint32x4x4_t res_idx_mask{{0}};
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000210 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100211 uint8x16_t mask_u8{0};
212 if (op == ReductionOperation::ARG_IDX_MIN)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000213 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100214 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000215 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
216 }
217 else
218 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100219 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000220 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
221 }
222
223 // Widen vectors
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100224 auto wide_u16_1 =
225 wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
226 auto wide_u16_2 =
227 wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
228 auto wide_u32_1 =
229 wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
230 auto wide_u32_2 =
231 wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
232 auto wide_u32_3 =
233 wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
234 auto wide_u32_4 =
235 wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000236 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
237 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
238 res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
239 res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
240 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
241 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
242 res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
243 res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);
244
245 uint32_t res = 0xFFFFFFFF;
246 int iter = 0;
247 do
248 {
249 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
250 pmin = wrapper::vpmin(pmin, pmin);
251 res = std::min(wrapper::vgetlane(pmin, 0), res);
252 iter++;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100253 } while (iter < 4);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000254
255 return (res - 0xFFFFFFFF);
256}
Luca Foschianiee939fb2020-01-28 10:38:07 +0000257
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000258#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitasfad18382019-06-05 15:12:22 +0100259template <>
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100260uint32x4x4_t
261calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000262{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100263 uint32x4x2_t mask{0};
264 uint16x8_t mask_u16{0};
265 if (op == ReductionOperation::ARG_IDX_MIN)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000266 {
267 mask_u16 = wrapper::vcgt(b, a);
268 }
269 else
270 {
271 mask_u16 = wrapper::vclt(b, a);
272 }
273 mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
274 mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100275 uint32x4x2_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, {idx + 4, idx + 5, idx + 6, idx + 7}}};
276 if (axis != 0)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000277 {
278 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
279 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
280 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100281 uint32x4x4_t res = {wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
282 wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]), 0, 0};
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000283
284 return res;
285}
286
Usama Arifa4a08ad2019-05-20 12:38:33 +0100287// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
288inline float16x4_t calculate_min(float16x8_t in)
289{
290 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
291 pmin = wrapper::vpmin(pmin, pmin);
292 return wrapper::vpmin(pmin, pmin);
293}
294// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
295inline float16x4_t calculate_max(float16x8_t in)
296{
297 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
298 pmax = wrapper::vpmax(pmax, pmax);
299 return wrapper::vpmax(pmax, pmax);
300}
301
Usama Arif0a5a57a2019-05-23 14:20:33 +0100302template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000303uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
304{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100305 uint32x4x2_t res_idx_mask{0};
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000306 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
307 uint16x8_t mask_u16;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100308 if (op == ReductionOperation::ARG_IDX_MIN)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000309 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100310 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000311 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
312 }
313 else
314 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100315 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000316 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
317 }
318
319 // Widen vectors
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100320 auto wide_u32_1 =
321 wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
322 auto wide_u32_2 =
323 wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000324 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
325 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
326 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
327 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
328
329 uint32_t res = 0xFFFFFFFF;
Michalis Spyrouc89998f2021-08-26 14:11:44 +0100330 uint32_t iter = 0;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000331 do
332 {
333 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
334 pmin = wrapper::vpmin(pmin, pmin);
335 res = std::min(wrapper::vgetlane(pmin, 0), res);
336 iter++;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100337 } while (iter < 2);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000338
339 return (res - 0xFFFFFFFF);
340}
341#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
342
Georgios Pinitasd9769582017-08-03 10:19:40 +0100343template <class F>
344class Reducer
345{
346public:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000347 static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100348 {
349 // Set out window
350 Window out_window(window);
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100351 out_window.set(Window::DimX, Window::Dimension(0, 1, 1));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100352
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100353 f(window, out_window, input, output, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100354 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000355 static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100356 {
357 // Set in window
358 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000359 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100360
361 in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000362 out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100363
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100364 f(in_window, out_window, input, output, 1, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100365 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000366 static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100367 {
368 // Set in window
369 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000370 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100371
372 in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000373 out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100374
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100375 f(in_window, out_window, input, output, 2, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100376 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000377 static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100378 {
379 // Set in/out window
380 Window in_window(window);
381 Window out_window(window);
382
383 in_window.set(3, Window::Dimension(0, 1, 1));
384 out_window.set(3, Window::Dimension(0, 1, 1));
385
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100386 f(in_window, out_window, input, output, 3, op);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100387 }
388};
389
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000390template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100391struct RedOpX
Georgios Pinitasd9769582017-08-03 10:19:40 +0100392{
Michele Di Giorgio33f41fa2021-03-09 14:09:08 +0000393 /** SIMD vector tag type. */
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100394 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
395
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100396 inline void operator()(
397 const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100398 {
Manuel Bottini6a5eee72021-04-30 12:37:04 +0100399 const size_t input_dim_0 = in->info()->dimension(0);
400 const int window_step_x = 16 / sizeof(T);
401 const auto window_start_x = static_cast<int>(in_window.x().start());
402 const auto window_end_x = static_cast<int>(in_window.x().end());
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100403
Georgios Pinitas412b7892020-11-11 21:05:24 +0000404 Window in_win_no_pad = in_window;
405 in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1));
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100406
Georgios Pinitas412b7892020-11-11 21:05:24 +0000407 Iterator input(in, in_win_no_pad);
408 Iterator output(out, out_window);
409
Omar Al Khatibe317baf2022-12-15 09:12:12 +0000410 execute_window_loop(
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100411 in_win_no_pad,
412 [&](const Coordinates &)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000413 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100414 const auto input_ptr = reinterpret_cast<const T *>(input.ptr());
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000415
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100416 auto init_res_value = static_cast<T>(0.f);
417 switch (op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100418 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100419 case ReductionOperation::ARG_IDX_MAX:
420 case ReductionOperation::ARG_IDX_MIN:
421 case ReductionOperation::MIN:
422 case ReductionOperation::MAX:
423 {
424 init_res_value = static_cast<T>(*input_ptr);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000425 break;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100426 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000427 case ReductionOperation::PROD:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100428 {
429 init_res_value = static_cast<T>(1.f);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000430 break;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100431 }
432 default:
433 break;
434 }
435 auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
436 uint32x4x4_t vec_res_idx{{0}};
437
438 // Compute window_step_x elements per iteration
439 int x = window_start_x;
440 for (; x <= (window_end_x - window_step_x); x += window_step_x)
441 {
442 const auto vec_elements = wrapper::vloadq(input_ptr + x);
443 switch (op)
444 {
445 case ReductionOperation::SUM_SQUARE:
446 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
447 break;
448 case ReductionOperation::MEAN_SUM:
449 case ReductionOperation::SUM:
450 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
451 break;
452 case ReductionOperation::PROD:
453 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
454 break;
455 case ReductionOperation::ARG_IDX_MIN:
456 {
457 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
458 vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value,
459 vec_res_idx, op, 0);
460 vec_res_value = temp_vec_res_value;
461 break;
462 }
463 case ReductionOperation::ARG_IDX_MAX:
464 {
465 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
466 vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value,
467 vec_res_idx, op, 0);
468 vec_res_value = temp_vec_res_value;
469 break;
470 }
471 case ReductionOperation::MIN:
472 {
473 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
474 break;
475 }
476 case ReductionOperation::MAX:
477 {
478 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
479 break;
480 }
481 default:
482 ARM_COMPUTE_ERROR("Not supported");
483 }
484 }
485
486 switch (op)
487 {
488 case ReductionOperation::SUM:
489 case ReductionOperation::MEAN_SUM:
490 case ReductionOperation::SUM_SQUARE:
491 {
492#ifdef ARM_COMPUTE_DEBUG_ENABLED
493 auto res = static_cast<T>(0.f);
494 for (int i = 0; i < S; ++i)
495 {
496 res += wrapper::vgetlane(vec_res_value, i);
497 }
498#else // ARM_COMPUTE_DEBUG_ENABLED
499 auto carry_res =
500 wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
501 for (int i = 0; i < S / 4; ++i)
502 {
503 carry_res = wrapper::vpadd(carry_res, carry_res);
504 }
505 auto res = wrapper::vgetlane(carry_res, 0);
506#endif // ARM_COMPUTE_DEBUG_ENABLED
507 if (op == ReductionOperation::SUM_SQUARE)
508 {
509 // Compute left-over elements
510 for (; x < window_end_x; ++x)
511 {
512 res += (*(input_ptr + x)) * (*(input_ptr + x));
513 }
514 }
515 else
516 {
517 // Compute left-over elements
518 for (; x < window_end_x; ++x)
519 {
520 res += *(input_ptr + x);
521 }
522 }
523
524 if (op == ReductionOperation::MEAN_SUM)
525 {
526 res /= input_dim_0;
527 }
528
529 *(reinterpret_cast<T *>(output.ptr())) = res;
530 break;
531 }
532 case ReductionOperation::PROD:
533 {
534 auto carry_res =
535 wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
536 T res = 1;
537 for (int i = 0; i < S / 2; ++i)
538 {
539 res *= wrapper::vgetlane(carry_res, i);
540 }
541
542 // Compute left-over elements
543 for (; x < window_end_x; ++x)
544 {
545 res *= *(input_ptr + x);
546 }
547
548 *(reinterpret_cast<T *>(output.ptr())) = res;
549 break;
550 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000551 case ReductionOperation::ARG_IDX_MIN:
552 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100553 auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
554 auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
555
556 // Compute left-over elements
557 for (; x < window_end_x; ++x)
558 {
559 if (*(input_ptr + x) < res)
560 {
561 idx = x;
562 res = *(input_ptr + x);
563 }
564 }
565 *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000566 break;
567 }
568 case ReductionOperation::ARG_IDX_MAX:
569 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100570 auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
571 auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
572
573 // Compute left-over elements
574 for (; x < window_end_x; ++x)
575 {
576 if (*(input_ptr + x) > res)
577 {
578 idx = x;
579 res = *(input_ptr + x);
580 }
581 }
582 *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000583 break;
584 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100585 case ReductionOperation::MIN:
586 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100587 auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
588
589 // Compute left-over elements
590 for (; x < window_end_x; ++x)
591 {
592 res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
593 }
594 *(reinterpret_cast<T *>(output.ptr())) = res;
Usama Arifa4a08ad2019-05-20 12:38:33 +0100595 break;
596 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100597 case ReductionOperation::MAX:
598 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100599 auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
600
601 // Compute left-over elements
602 for (; x < window_end_x; ++x)
603 {
604 res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
605 }
606 *(reinterpret_cast<T *>(output.ptr())) = res;
Usama Arif28f0dd92019-05-20 13:44:34 +0100607 break;
608 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000609 default:
610 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100611 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100612 },
613 input, output);
giuros01154bc1c2019-03-26 17:44:40 +0000614 }
615};
616
Luca Foschianiee939fb2020-01-28 10:38:07 +0000617template <typename T>
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100618struct RedOpX_quantized
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100619{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100620 inline void operator()(
621 const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100622 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000623 using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
624
Omar Al Khatibe317baf2022-12-15 09:12:12 +0000625 const auto oq_info = out->info()->quantization_info().uniform();
626
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100627 const TensorInfo in_info = *(in->info());
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100628 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
629
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100630 const int window_step_x = 16 / sizeof(T);
631 const auto window_start_x = static_cast<int>(in_window.x().start());
632 const auto window_end_x = static_cast<int>(in_window.x().end());
633
Georgios Pinitas412b7892020-11-11 21:05:24 +0000634 Window in_win_no_pad = in_window;
635 in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1));
636
637 Iterator input(in, in_win_no_pad);
638 Iterator output(out, out_window);
639
Mohammed Suhail Munshiecaa10a2023-02-09 11:52:06 +0000640 const auto in_offset = static_cast<float>(iq_info.offset);
Omar Al Khatibe317baf2022-12-15 09:12:12 +0000641 const float in_scale = iq_info.scale;
642
Mohammed Suhail Munshiecaa10a2023-02-09 11:52:06 +0000643 const auto out_offset = static_cast<float>(oq_info.offset);
Omar Al Khatibe317baf2022-12-15 09:12:12 +0000644 const float out_scale = oq_info.scale;
645
Mohammed Suhail Munshiecaa10a2023-02-09 11:52:06 +0000646 const auto num_elements = static_cast<float>(in_info.dimension(0));
Omar Al Khatibe317baf2022-12-15 09:12:12 +0000647
648 const float A = in_scale / (out_scale * num_elements);
649 const float B = out_offset - (in_scale * in_offset) / (out_scale);
650
651 execute_window_loop(
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100652 in_win_no_pad,
653 [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100654 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100655 const auto input_ptr = reinterpret_cast<T *>(input.ptr());
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100656
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100657 auto vec_res_value1 =
658 wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
659 auto vec_res_value2 =
660 wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
661 auto vec_res_value3 =
662 wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
663 auto vec_res_value4 =
664 wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
665
666 auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
667 auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
668 auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
669 auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));
670
671 typename wrapper::traits::neon_vector<T, 16>::type vec_res_value = {0};
672
673 if (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN ||
674 op == ReductionOperation::MIN || op == ReductionOperation::MAX)
Pablo Marquez Telloc4c595a2021-05-04 17:23:09 +0100675 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100676 vec_res_value = wrapper::vdup_n(*input_ptr, wrapper::traits::vector_128_tag{});
Pablo Marquez Telloc4c595a2021-05-04 17:23:09 +0100677 }
Pablo Marquez Telloc4c595a2021-05-04 17:23:09 +0100678
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100679 uint32x4x4_t vec_res_idx{{0}};
680 // Compute window_step_x elements per iteration
681 int x = window_start_x;
682 for (; x <= (window_end_x - window_step_x); x += window_step_x)
Pablo Marquez Telloc4c595a2021-05-04 17:23:09 +0100683 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100684 const auto vec_elements = wrapper::vloadq(input_ptr + x);
685 switch (op)
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100686 {
687 case ReductionOperation::SUM:
688 case ReductionOperation::MEAN_SUM:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100689 {
690 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
691 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
692
693 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
694 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
695 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
696 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
697
698 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
699 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
700 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
701 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100702 break;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100703 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100704 case ReductionOperation::PROD:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100705 {
706 const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
707 const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale);
708
709 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
710 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
711
712 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
713 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
714 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
715 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
716
717 auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
718 auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
719 auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
720 auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
721
722 //de-quantize vec_elements
723 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
724 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
725 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
726 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
727
728 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
729 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
730 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
731 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100732 break;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100733 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100734 case ReductionOperation::ARG_IDX_MIN:
735 {
736 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100737 vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>(
738 x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
739 vec_res_value = temp_vec_res_value;
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100740 break;
741 }
742 case ReductionOperation::ARG_IDX_MAX:
743 {
744 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100745 vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>(
746 x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
747 vec_res_value = temp_vec_res_value;
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100748 break;
749 }
750 case ReductionOperation::MIN:
751 {
752 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
753 break;
754 }
755 case ReductionOperation::MAX:
756 {
757 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
758 break;
759 }
760 default:
761 ARM_COMPUTE_ERROR("Not supported");
762 }
763 }
764
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100765 switch (op)
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100766 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100767 case ReductionOperation::ARG_IDX_MIN:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100768 {
769 auto idx =
770 calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
771 auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
772
773 // Compute left-over elements
774 for (; x < window_end_x; ++x)
775 {
776 if (*(input_ptr + x) < res)
777 {
778 idx = x;
779 res = *(input_ptr + x);
780 }
781 }
782 *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
783 break;
784 }
785 case ReductionOperation::ARG_IDX_MAX:
786 {
787 auto idx =
788 calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
789 auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
790
791 // Compute left-over elements
792 for (; x < window_end_x; ++x)
793 {
794 if (*(input_ptr + x) > res)
795 {
796 idx = x;
797 res = *(input_ptr + x);
798 }
799 }
800 *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
801 break;
802 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100803 case ReductionOperation::MIN:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100804 {
805 auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
806
807 // Compute left-over elements
808 for (; x < window_end_x; ++x)
809 {
810 res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
811 }
812 *(reinterpret_cast<T *>(output.ptr())) = res;
813 break;
814 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100815 case ReductionOperation::MAX:
816 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100817 auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
818
819 // Compute left-over elements
820 for (; x < window_end_x; ++x)
821 {
822 res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
823 }
824 *(reinterpret_cast<T *>(output.ptr())) = res;
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100825 break;
826 }
827 case ReductionOperation::PROD:
828 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100829 auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
830 carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
831 carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
832
833 float res = wrapper::vgetlane(carry_res, 0);
834 res *= wrapper::vgetlane(carry_res, 1);
835 res *= wrapper::vgetlane(carry_res, 2);
836 res *= wrapper::vgetlane(carry_res, 3);
837
838 // Compute left-over elements
839 for (; x < window_end_x; ++x)
840 {
841 //de-quantize input
842 if (std::is_same<T, uint8_t>::value)
843 {
844 res *= dequantize_qasymm8(*(input_ptr + x), iq_info);
845 }
846 else
847 {
848 res *= dequantize_qasymm8_signed(*(input_ptr + x), iq_info);
849 }
850 }
851
852 //re-quantize result
853 if (std::is_same<T, uint8_t>::value)
854 {
855 res = quantize_qasymm8(res, iq_info);
856 }
857 else
858 {
859 res = quantize_qasymm8_signed(res, iq_info);
860 }
861
862 *reinterpret_cast<T *>(output.ptr()) = static_cast<T>(res);
863 break;
864 }
865 case ReductionOperation::SUM:
866 case ReductionOperation::MEAN_SUM:
867 {
868 auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
869 carry_res = wrapper::vadd(carry_res, vec_res_value3);
870 carry_res = wrapper::vadd(carry_res, vec_res_value4);
871
872 auto carry_paddition =
873 wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
874 carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
875 auto res = static_cast<int32_t>(wrapper::vgetlane(carry_paddition, 0));
876
877 // Compute left-over elements
878 for (; x < window_end_x; ++x)
879 {
880 res += *(input_ptr + x);
881 }
882
883 if (op == ReductionOperation::MEAN_SUM)
884 {
885 const int32_t resFinal = A * (static_cast<float>(res)) + B;
886
887 *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(resFinal);
888 }
889 else
890 {
891 // Subtract accumulated offsets
892 res -= (in_info.dimension(0) - 1) * iq_info.offset;
893 *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(res);
894 }
895
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100896 break;
897 }
898 default:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100899 ARM_COMPUTE_ERROR("Not supported");
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100900 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100901 },
902 input, output);
903 }
904};
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100905
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100906template <typename T, int S>
907struct RedOpYZW
908{
909 /** SIMD vector tag type. */
910 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
911 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
912
913 inline void operator()(const Window &in_window,
914 Window &out_window,
915 const ITensor *in,
916 ITensor *out,
917 int axis,
918 const ReductionOperation op)
919 {
920 const TensorInfo in_info = *(in->info());
921 const int window_step_x = 16 / sizeof(T);
922 const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
923 const auto window_end_x_tmp = static_cast<int>(in_window.x().end());
924 // As it split over x-axis, need to set the correct spiltted window start and end.
925 const auto window_start_x = static_cast<int>(0);
926 const auto window_end_x = static_cast<int>(in_window.shape().x());
927
928 Window in_win_no_pad = in_window;
929 in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
930 Window out_win_no_pad = out_window;
931 out_win_no_pad.set(Window::DimX,
932 Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));
933
934 Iterator input(in, in_win_no_pad);
935 Iterator output(out, out_win_no_pad);
936
937 execute_window_loop(
938 in_win_no_pad,
939 [&](const Coordinates &)
940 {
941 const auto input_ptr = reinterpret_cast<T *>(input.ptr());
942
943 // Compute window_step_x elements per iteration
944 int x = window_start_x;
945 for (; x <= (window_end_x - window_step_x); x += window_step_x)
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100946 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100947 neon_vector vec_res_value = {0};
948 switch (op)
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100949 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100950 case ReductionOperation::ARG_IDX_MAX:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100951 case ReductionOperation::ARG_IDX_MIN:
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100952 case ReductionOperation::MIN:
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100953 case ReductionOperation::MAX:
954 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100955 vec_res_value = wrapper::vloadq(input_ptr + x);
956 break;
957 }
958 case ReductionOperation::PROD:
959 {
960 vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100961 break;
962 }
963 default:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100964 {
965 vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
966 break;
967 }
968 }
969 uint32x4x4_t vec_res_idx{{0}};
970
971 for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
972 {
973 const T *in_ptr =
974 reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);
975 const auto vec_elements = wrapper::vloadq(in_ptr);
976 switch (op)
977 {
978 case ReductionOperation::SUM:
979 case ReductionOperation::MEAN_SUM:
980 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
981 break;
982 case ReductionOperation::SUM_SQUARE:
983 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
984 break;
985 case ReductionOperation::PROD:
986 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
987 break;
988 case ReductionOperation::ARG_IDX_MIN:
989 {
990 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
991 vec_res_idx =
992 calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
993 vec_res_value = temp_vec_res_value;
994 break;
995 }
996 case ReductionOperation::ARG_IDX_MAX:
997 {
998 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
999 vec_res_idx =
1000 calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
1001 vec_res_value = temp_vec_res_value;
1002 break;
1003 }
1004 case ReductionOperation::MIN:
1005 {
1006 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1007 break;
1008 }
1009 case ReductionOperation::MAX:
1010 {
1011 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1012 break;
1013 }
1014 default:
1015 ARM_COMPUTE_ERROR("Not supported");
1016 }
1017 }
1018
1019 if (op == ReductionOperation::MEAN_SUM)
1020 {
1021 auto vec_width_inv =
1022 wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
1023 vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
1024 }
1025
1026 if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
1027 {
1028 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]);
1029#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1030 if (std::is_same<T, float16_t>::value)
1031 {
1032 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, vec_res_idx.val[1]);
1033 }
1034#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1035 }
1036 else
1037 {
1038 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001039 }
1040 }
1041
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001042 // Compute left-over elements
1043 for (; x < window_end_x; ++x)
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001044 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001045 auto res_value = 0.f;
1046 switch (op)
1047 {
1048 case ReductionOperation::ARG_IDX_MAX:
1049 case ReductionOperation::ARG_IDX_MIN:
1050 case ReductionOperation::MIN:
1051 case ReductionOperation::MAX:
1052 {
1053 res_value = *(input_ptr + x);
1054 break;
1055 }
1056 case ReductionOperation::PROD:
1057 {
1058 res_value = static_cast<T>(1.f);
1059 break;
1060 }
1061 default:
1062 {
1063 res_value = static_cast<T>(0.f);
1064 break;
1065 }
1066 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001067
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001068 uint32_t res_idx = 0;
1069 for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1070 {
1071 const T *in_ptr =
1072 reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);
1073
1074 switch (op)
1075 {
1076 case ReductionOperation::SUM:
1077 case ReductionOperation::MEAN_SUM:
1078 res_value += *in_ptr;
1079 break;
1080 case ReductionOperation::SUM_SQUARE:
1081 res_value += *in_ptr * *in_ptr;
1082 break;
1083 case ReductionOperation::PROD:
1084 res_value *= *in_ptr;
1085 break;
1086 case ReductionOperation::ARG_IDX_MIN:
1087 {
1088 if (*in_ptr < res_value)
1089 {
1090 res_value = *in_ptr;
1091 res_idx = dim;
1092 }
1093 break;
1094 }
1095 case ReductionOperation::ARG_IDX_MAX:
1096 {
1097 if (*in_ptr > res_value)
1098 {
1099 res_value = *in_ptr;
1100 res_idx = dim;
1101 }
1102 break;
1103 }
1104 case ReductionOperation::MIN:
1105 {
1106 res_value = *in_ptr < res_value ? *in_ptr : res_value;
1107 break;
1108 }
1109 case ReductionOperation::MAX:
1110 {
1111 res_value = *in_ptr > res_value ? *in_ptr : res_value;
1112 break;
1113 }
1114 default:
1115 ARM_COMPUTE_ERROR("Not supported");
1116 }
1117 }
1118
1119 if (op == ReductionOperation::MEAN_SUM)
1120 {
1121 res_value /= in_info.dimension(axis);
1122 }
1123
1124 if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
1125 {
1126 *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx;
1127 }
1128 else
1129 {
1130 *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value;
1131 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001132 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001133 },
1134 input, output);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001135 }
1136};
1137
1138template <typename T, int S, int axis, ReductionOperation op>
1139struct RedOpYZW_complex
1140{
Michele Di Giorgio33f41fa2021-03-09 14:09:08 +00001141 /** SIMD vector tag type. */
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001142 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
1143 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
1144
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001145 inline void operator()(
1146 const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation)
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001147 {
1148 ARM_COMPUTE_ERROR_ON(axis != 2);
Georgios Pinitas412b7892020-11-11 21:05:24 +00001149 ARM_COMPUTE_ERROR_ON(op != ReductionOperation::SUM);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001150
Georgios Pinitas412b7892020-11-11 21:05:24 +00001151 const TensorInfo in_info = *(in->info());
1152 const size_t stride_z = in_info.strides_in_bytes()[axis];
1153 const int window_step_x = 16 / sizeof(T);
1154 const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
1155 const auto window_end_x_tmp = static_cast<int>(in_window.x().end());
1156 // As it split over x-axis, need to set the correct spiltted window start and end.
1157 const auto window_start_x = static_cast<int>(0);
1158 const auto window_end_x = static_cast<int>(in_window.shape().x());
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001159
Georgios Pinitas412b7892020-11-11 21:05:24 +00001160 Window in_win_no_pad = in_window;
1161 in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
1162 Window out_win_no_pad = out_window;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001163 out_win_no_pad.set(Window::DimX,
1164 Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001165
Georgios Pinitas412b7892020-11-11 21:05:24 +00001166 Iterator input(in, in_win_no_pad);
1167 Iterator output(out, out_win_no_pad);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001168
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001169 execute_window_loop(
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001170 in_win_no_pad,
1171 [&](const Coordinates &)
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001172 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001173 // Compute window_step_x elements per iteration
1174 int x = window_start_x;
1175 for (; x <= (window_end_x - window_step_x); x += window_step_x)
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001176 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001177 neon_vector vec_res_value_0 = {0};
1178 neon_vector vec_res_value_1 = {0};
Georgios Pinitas412b7892020-11-11 21:05:24 +00001179
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001180 vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
1181 vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001182
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001183 T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
1184 for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1185 {
1186 T *in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
1187 T *in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim);
1188
1189 const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
1190 const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);
1191
1192 vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
1193 vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
1194 }
1195
1196 wrapper::vstore(out_ptr, vec_res_value_0);
1197 wrapper::vstore(out_ptr + 4, vec_res_value_1);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001198 }
1199
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001200 // Compute left-over elements
1201 for (; x < window_end_x; ++x)
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001202 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001203 auto res_value_0 = 0.f;
1204 auto res_value_1 = 0.f;
1205
1206 T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
1207 for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1208 {
1209 T *in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
1210 res_value_0 += *in_ptr;
1211 res_value_1 += *(in_ptr + 1);
1212 }
1213 *out_ptr = res_value_0;
1214 *(out_ptr + 1) = res_value_1;
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001215 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001216 },
1217 input, output);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001218 }
1219};
1220
1221template <typename T>
1222struct RedOpYZW_quantized
1223{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001224 inline void operator()(const Window &in_window,
1225 Window &out_window,
1226 const ITensor *in,
1227 ITensor *out,
1228 int axis,
1229 const ReductionOperation op)
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001230 {
Georgios Pinitas412b7892020-11-11 21:05:24 +00001231 const TensorInfo in_info = *(in->info());
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001232 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001233 using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001234
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001235 const auto oq_info = out->info()->quantization_info().uniform();
1236
Georgios Pinitas412b7892020-11-11 21:05:24 +00001237 const int window_step_x = 16 / sizeof(T);
1238 const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
1239 const auto window_end_x_tmp = static_cast<int>(in_window.x().end());
1240 // As it split over x-axis, need to set the correct spiltted window start and end.
1241 const auto window_start_x = static_cast<int>(0);
1242 const auto window_end_x = static_cast<int>(in_window.shape().x());
1243
1244 Window in_win_no_pad = in_window;
1245 in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
1246 Window out_win_no_pad = out_window;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001247 out_win_no_pad.set(Window::DimX,
1248 Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));
Georgios Pinitas412b7892020-11-11 21:05:24 +00001249
1250 Iterator input(in, in_win_no_pad);
1251 Iterator output(out, out_win_no_pad);
1252
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001253 using vector_type =
1254 typename wrapper::traits::neon_bitvector<PromotedType, wrapper::traits::BitWidth::W128>::type;
Pablo Marquez Telloc4c595a2021-05-04 17:23:09 +01001255 using vector_type_f = typename wrapper::traits::neon_vector<float, 4>::type;
1256
1257 vector_type vec_res_value1{};
1258 vector_type vec_res_value2{};
1259 vector_type vec_res_value3{};
1260 vector_type vec_res_value4{};
1261
1262 vector_type_f vec_res_value1_f{};
1263 vector_type_f vec_res_value2_f{};
1264 vector_type_f vec_res_value3_f{};
1265 vector_type_f vec_res_value4_f{};
1266
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001267 const float in_offset = static_cast<float>(iq_info.offset);
1268 const float in_scale = iq_info.scale;
1269
1270 const float out_offset = static_cast<float>(oq_info.offset);
1271 const float out_scale = oq_info.scale;
1272
1273 const float num_elements = static_cast<float>(in_info.dimension(axis));
1274
1275 const float A = in_scale / (out_scale * num_elements);
1276 const float B = out_offset - (in_scale * in_offset) / (out_scale);
1277
1278 const auto vec_A = wrapper::vdup_n(static_cast<float>(A), wrapper::traits::vector_128_tag{});
1279 const auto vec_B = wrapper::vdup_n(static_cast<float>(B), wrapper::traits::vector_128_tag{});
1280
1281 execute_window_loop(
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001282 in_win_no_pad,
1283 [&](const Coordinates &)
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001284 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001285 const auto input_ptr = reinterpret_cast<T *>(input.ptr());
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001286
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001287 // Compute window_step_x elements per iteration
1288 int x = window_start_x;
1289 for (; x <= (window_end_x - window_step_x); x += window_step_x)
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001290 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001291 uint32x4x4_t vec_res_idx{{0}};
1292 vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1293 vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1294 vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1295 vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1296
1297 vec_res_value1_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1298 vec_res_value2_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1299 vec_res_value3_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1300 vec_res_value4_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1301
1302 auto vec_res_value = wrapper::vloadq(input_ptr + x);
1303
1304 for (unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001305 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001306 const T *in_ptr = input_ptr + x + in_info.strides_in_bytes()[axis] * index_dim;
1307 const auto vec_elements = wrapper::vloadq(in_ptr);
1308 switch (op)
1309 {
1310 case ReductionOperation::SUM:
1311 case ReductionOperation::MEAN_SUM:
1312 {
1313 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
1314 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
1315
1316 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
1317 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
1318 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
1319 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
1320
1321 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
1322 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
1323 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
1324 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
1325 break;
1326 }
1327 case ReductionOperation::PROD:
1328 {
1329 const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset),
1330 wrapper::traits::vector_128_tag{});
1331 const auto scale32x4f_4 =
1332 wrapper::vdup_n(iq_info.scale, wrapper::traits::vector_128_tag{});
1333
1334 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
1335 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
1336
1337 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
1338 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
1339 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
1340 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
1341
1342 auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
1343 auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
1344 auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
1345 auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
1346
1347 //de-quantize vec_elements
1348 temp32x4f_1 = wrapper::vmul(wrapper::vsub(temp32x4f_1, offset32x4f_4), scale32x4f_4);
1349 temp32x4f_2 = wrapper::vmul(wrapper::vsub(temp32x4f_2, offset32x4f_4), scale32x4f_4);
1350 temp32x4f_3 = wrapper::vmul(wrapper::vsub(temp32x4f_3, offset32x4f_4), scale32x4f_4);
1351 temp32x4f_4 = wrapper::vmul(wrapper::vsub(temp32x4f_4, offset32x4f_4), scale32x4f_4);
1352
1353 vec_res_value1_f = wrapper::vmul(temp32x4f_1, vec_res_value1_f);
1354 vec_res_value2_f = wrapper::vmul(temp32x4f_2, vec_res_value2_f);
1355 vec_res_value3_f = wrapper::vmul(temp32x4f_3, vec_res_value3_f);
1356 vec_res_value4_f = wrapper::vmul(temp32x4f_4, vec_res_value4_f);
1357 break;
1358 }
1359 case ReductionOperation::ARG_IDX_MIN:
1360 {
1361 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1362 vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value,
1363 vec_res_idx, op, axis);
1364 vec_res_value = temp_vec_res_value;
1365 break;
1366 }
1367 case ReductionOperation::ARG_IDX_MAX:
1368 {
1369 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1370 vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value,
1371 vec_res_idx, op, axis);
1372 vec_res_value = temp_vec_res_value;
1373 break;
1374 }
1375 case ReductionOperation::MIN:
1376 {
1377 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1378 break;
1379 }
1380 case ReductionOperation::MAX:
1381 {
1382 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1383 break;
1384 }
1385 default:
1386 ARM_COMPUTE_ERROR("Not supported");
1387 }
1388 }
1389
1390 switch (op)
1391 {
1392 case ReductionOperation::ARG_IDX_MIN:
1393 case ReductionOperation::ARG_IDX_MAX:
1394 {
1395 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x), vec_res_idx.val[0]);
1396 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 4, vec_res_idx.val[1]);
1397 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 8, vec_res_idx.val[2]);
1398 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 12,
1399 vec_res_idx.val[3]);
1400 break;
1401 }
1402 case ReductionOperation::MIN:
1403 case ReductionOperation::MAX:
1404 {
1405 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), vec_res_value);
1406 break;
1407 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001408 case ReductionOperation::SUM:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001409 {
1410 // Subtract offsets
1411 auto offsets = vdupq_n_s32((in_info.dimension(axis) - 1) * iq_info.offset);
1412
1413 auto vec_res_s_value1 = wrapper::vreinterpret(vec_res_value1);
1414 auto vec_res_s_value2 = wrapper::vreinterpret(vec_res_value2);
1415 auto vec_res_s_value3 = wrapper::vreinterpret(vec_res_value3);
1416 auto vec_res_s_value4 = wrapper::vreinterpret(vec_res_value4);
1417
1418 vec_res_s_value1 = wrapper::vsub(vec_res_s_value1, offsets);
1419 vec_res_s_value2 = wrapper::vsub(vec_res_s_value2, offsets);
1420 vec_res_s_value3 = wrapper::vsub(vec_res_s_value3, offsets);
1421 vec_res_s_value4 = wrapper::vsub(vec_res_s_value4, offsets);
1422
1423 const auto temp16x8t_1 =
1424 wrapper::vcombine(wrapper::vqmovn(vec_res_s_value1), wrapper::vqmovn(vec_res_s_value2));
1425 const auto temp16x8t_2 =
1426 wrapper::vcombine(wrapper::vqmovn(vec_res_s_value3), wrapper::vqmovn(vec_res_s_value4));
1427
1428 combine_and_store<T>(temp16x8t_1, temp16x8t_2, output, x);
1429 break;
1430 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001431 case ReductionOperation::MEAN_SUM:
1432 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001433 vec_res_value1_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value1), vec_A);
1434 vec_res_value2_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value2), vec_A);
1435 vec_res_value3_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value3), vec_A);
1436 vec_res_value4_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value4), vec_A);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001437
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001438#ifdef __aarch64__
1439 vec_res_value1 = wrapper::vcvta<PromotedType>(vec_res_value1_f);
1440 vec_res_value2 = wrapper::vcvta<PromotedType>(vec_res_value2_f);
1441 vec_res_value3 = wrapper::vcvta<PromotedType>(vec_res_value3_f);
1442 vec_res_value4 = wrapper::vcvta<PromotedType>(vec_res_value4_f);
1443#else // defined(__aarch64__)
1444 vec_res_value1 = wrapper::vcvt<PromotedType>(vec_res_value1_f);
1445 vec_res_value2 = wrapper::vcvt<PromotedType>(vec_res_value2_f);
1446 vec_res_value3 = wrapper::vcvt<PromotedType>(vec_res_value3_f);
1447 vec_res_value4 = wrapper::vcvt<PromotedType>(vec_res_value4_f);
1448#endif // __aarch64__
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001449
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001450 const auto temp16x8t_1 =
1451 wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1452 const auto temp16x8t_2 =
1453 wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1454 auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
1455
1456 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001457 break;
1458 }
1459 case ReductionOperation::PROD:
1460 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001461 const auto offset32x4f_4 =
1462 wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
1463 const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale));
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001464
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001465 //re-quantize
1466 vec_res_value1_f =
1467 wrapper::vadd(wrapper::vmul(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
1468 vec_res_value2_f =
1469 wrapper::vadd(wrapper::vmul(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
1470 vec_res_value3_f =
1471 wrapper::vadd(wrapper::vmul(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
1472 vec_res_value4_f =
1473 wrapper::vadd(wrapper::vmul(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001474
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001475 vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f);
1476 vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f);
1477 vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f);
1478 vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001479
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001480 const auto temp16x8t_1 =
1481 wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1482 const auto temp16x8t_2 =
1483 wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1484 auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001485
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001486 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001487 break;
1488 }
1489 default:
1490 ARM_COMPUTE_ERROR("Not supported");
1491 }
1492 }
1493
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001494 // Compute left-over elements
1495 for (; x < window_end_x; ++x)
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001496 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001497 float res_value = 0.f;
1498 int32_t res_value_q = 0;
1499
1500 switch (op)
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001501 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001502 case ReductionOperation::ARG_IDX_MAX:
1503 case ReductionOperation::ARG_IDX_MIN:
1504 case ReductionOperation::MIN:
1505 case ReductionOperation::MAX:
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001506 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001507 res_value = *(input_ptr + x);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001508 break;
1509 }
1510 case ReductionOperation::PROD:
1511 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001512 res_value = static_cast<T>(1.0f);
1513 break;
1514 }
1515 default:
1516 {
1517 res_value = static_cast<T>(0.0f);
1518 break;
1519 }
1520 }
1521 uint32_t res_idx = 0;
1522
1523 for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1524 {
1525 const T *in_ptr =
1526 reinterpret_cast<T *>(input.ptr() + x + in_info.strides_in_bytes()[axis] * dim);
1527 switch (op)
1528 {
1529 case ReductionOperation::SUM:
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001530 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001531 res_value += *in_ptr;
1532 break;
1533 }
1534 case ReductionOperation::MEAN_SUM:
1535 {
1536 res_value_q += *in_ptr;
1537 break;
1538 }
1539 case ReductionOperation::SUM_SQUARE:
1540 {
1541 res_value += *in_ptr * *in_ptr;
1542 break;
1543 }
1544 case ReductionOperation::PROD:
1545 {
1546 //de-quantize input
1547 if (std::is_same<T, uint8_t>::value)
1548 {
1549 res_value *= dequantize_qasymm8(*in_ptr, iq_info);
1550 }
1551 else
1552 {
1553 res_value *= dequantize_qasymm8_signed(*in_ptr, iq_info);
1554 }
1555 break;
1556 }
1557 case ReductionOperation::ARG_IDX_MIN:
1558 {
1559 if (*in_ptr < res_value)
1560 {
1561 res_value = *in_ptr;
1562 res_idx = dim;
1563 }
1564 break;
1565 }
1566 case ReductionOperation::ARG_IDX_MAX:
1567 {
1568 if (*in_ptr > res_value)
1569 {
1570 res_value = *in_ptr;
1571 res_idx = dim;
1572 }
1573 break;
1574 }
1575 case ReductionOperation::MIN:
1576 {
1577 res_value = *in_ptr < res_value ? *in_ptr : res_value;
1578 break;
1579 }
1580 case ReductionOperation::MAX:
1581 {
1582 res_value = *in_ptr > res_value ? *in_ptr : res_value;
1583 break;
1584 }
1585 default:
1586 ARM_COMPUTE_ERROR("Not supported");
1587 }
1588 }
1589
1590 switch (op)
1591 {
1592 case ReductionOperation::MEAN_SUM:
1593 {
1594 // Apply previously calculated coefficients (with rounding on aarch64)
1595#ifdef __aarch64__
1596 const int32_t res =
1597 arm_compute::support::cpp11::round(A * (static_cast<float>(res_value_q)) + B);
1598#else // defined(__aarch64__)
1599 const int32_t res = A * (static_cast<float>(res_value_q)) + B;
1600#endif // __aarch64__
1601 *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res);
1602 break;
1603 }
1604 case ReductionOperation::SUM:
1605 {
1606 // Subtract accumulated offsets
1607 res_value -= (in_info.dimension(axis) - 1) * iq_info.offset;
1608 *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res_value);
1609 break;
1610 }
1611 case ReductionOperation::PROD:
1612 {
1613 //re-quantize result
1614 T res = 0;
1615 if (std::is_same<T, uint8_t>::value)
1616 {
1617 res = quantize_qasymm8(res_value, iq_info);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001618 }
1619 else
1620 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001621 res = quantize_qasymm8_signed(res_value, iq_info);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001622 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001623 *(reinterpret_cast<T *>(output.ptr() + x)) = res;
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001624 break;
1625 }
1626 case ReductionOperation::ARG_IDX_MIN:
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001627 case ReductionOperation::ARG_IDX_MAX:
1628 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001629 *(reinterpret_cast<uint32_t *>(output.ptr() + x * 4)) = res_idx;
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001630 break;
1631 }
1632 default:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001633 *(reinterpret_cast<T *>(output.ptr() + x)) = res_value;
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001634 }
1635 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001636 },
1637 input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001638 }
1639};
1640
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001641void reduce_op(
1642 const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001643{
giuros01154bc1c2019-03-26 17:44:40 +00001644 const bool is_complex = (input->info()->num_channels() == 2);
1645
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001646 if (is_complex)
giuros01154bc1c2019-03-26 17:44:40 +00001647 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001648 switch (axis)
giuros01154bc1c2019-03-26 17:44:40 +00001649 {
1650 case 2:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001651 switch (input->info()->data_type())
giuros01154bc1c2019-03-26 17:44:40 +00001652 {
1653 case DataType::F32:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001654 switch (op)
giuros01154bc1c2019-03-26 17:44:40 +00001655 {
1656 case ReductionOperation::SUM:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001657 return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(
1658 window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(),
1659 op);
giuros01154bc1c2019-03-26 17:44:40 +00001660 default:
1661 ARM_COMPUTE_ERROR("Not supported");
1662 }
1663 default:
1664 ARM_COMPUTE_ERROR("Not supported");
1665 }
1666 default:
1667 ARM_COMPUTE_ERROR("Not supported");
1668 }
Manuel Bottini6a5eee72021-04-30 12:37:04 +01001669 return;
giuros01154bc1c2019-03-26 17:44:40 +00001670 }
1671
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001672 switch (axis)
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001673 {
1674 case 0:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001675 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001676 switch (input->info()->data_type())
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001677 {
1678 case DataType::QASYMM8:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001679 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001680 return Reducer<RedOpX_quantized<uint8_t>>::reduceX(window, input, output,
1681 RedOpX_quantized<uint8_t>(), op);
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001682 }
Luca Foschianiee939fb2020-01-28 10:38:07 +00001683 case DataType::QASYMM8_SIGNED:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001684 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001685 return Reducer<RedOpX_quantized<int8_t>>::reduceX(window, input, output, RedOpX_quantized<int8_t>(),
1686 op);
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001687 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001688#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1689 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001690 return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001691#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1692 case DataType::F32:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001693 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001694 return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001695 }
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001696 case DataType::S32:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001697 {
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001698 return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001699 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001700 default:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001701 {
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001702 ARM_COMPUTE_ERROR("Not supported");
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001703 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001704 }
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001705 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001706 case 1:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001707 switch (input->info()->data_type())
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001708 {
1709 case DataType::QASYMM8:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001710 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001711 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceY(window, input, output,
1712 RedOpYZW_quantized<uint8_t>(), op);
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001713 }
Luca Foschianiee939fb2020-01-28 10:38:07 +00001714 case DataType::QASYMM8_SIGNED:
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001715 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001716 return Reducer<RedOpYZW_quantized<int8_t>>::reduceY(window, input, output,
1717 RedOpYZW_quantized<int8_t>(), op);
Omar Al Khatibe317baf2022-12-15 09:12:12 +00001718 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001719#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1720 case DataType::F16:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001721 return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(),
1722 op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001723#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1724 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001725 return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001726 case DataType::S32:
1727 return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001728 default:
1729 ARM_COMPUTE_ERROR("Not supported");
1730 }
1731 case 2:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001732 switch (input->info()->data_type())
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001733 {
1734 case DataType::QASYMM8:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001735 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceZ(window, input, output,
1736 RedOpYZW_quantized<uint8_t>(), op);
Luca Foschianiee939fb2020-01-28 10:38:07 +00001737 case DataType::QASYMM8_SIGNED:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001738 return Reducer<RedOpYZW_quantized<int8_t>>::reduceZ(window, input, output,
1739 RedOpYZW_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001740#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1741 case DataType::F16:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001742 return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(),
1743 op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001744#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1745 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001746 return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001747 case DataType::S32:
1748 return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001749 default:
1750 ARM_COMPUTE_ERROR("Not supported");
1751 }
1752 case 3:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001753 switch (input->info()->data_type())
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001754 {
1755 case DataType::QASYMM8:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001756 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceW(window, input, output,
1757 RedOpYZW_quantized<uint8_t>(), op);
Luca Foschianiee939fb2020-01-28 10:38:07 +00001758 case DataType::QASYMM8_SIGNED:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001759 return Reducer<RedOpYZW_quantized<int8_t>>::reduceW(window, input, output,
1760 RedOpYZW_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001761#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1762 case DataType::F16:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001763 return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(),
1764 op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001765#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1766 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001767 return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001768 case DataType::S32:
1769 return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001770 default:
1771 ARM_COMPUTE_ERROR("Not supported");
1772 }
1773 default:
1774 ARM_COMPUTE_ERROR("Unsupported reduction axis");
1775 }
1776}
John Richardson73d4aef2018-05-08 14:34:33 +01001777
1778Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1779{
1780 ARM_COMPUTE_UNUSED(op);
1781
1782 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
Georgios Pinitas8f5802f2019-02-22 11:08:32 +00001783 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
giuros01154bc1c2019-03-26 17:44:40 +00001784
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001785 if (input->num_channels() == 1)
giuros01154bc1c2019-03-26 17:44:40 +00001786 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001787 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8,
1788 DataType::S32, DataType::F16, DataType::F32);
giuros01154bc1c2019-03-26 17:44:40 +00001789 }
1790 else
1791 {
1792 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
1793 ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM);
1794 ARM_COMPUTE_RETURN_ERROR_ON(axis != 2);
1795 }
John Richardson73d4aef2018-05-08 14:34:33 +01001796
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001797 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions,
1798 "Reduction axis greater than max number of dimensions");
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001799 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
John Richardson73d4aef2018-05-08 14:34:33 +01001800
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001801 if (output->total_size() != 0)
John Richardson73d4aef2018-05-08 14:34:33 +01001802 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001803 bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001804 if (!is_arg_min_max)
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001805 {
1806 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
giuros01154bc1c2019-03-26 17:44:40 +00001807 ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels());
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001808 }
1809 else
1810 {
Michele Di Giorgio9637b2e2019-09-23 16:49:49 +01001811 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32, DataType::S32);
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001812 }
John Richardson73d4aef2018-05-08 14:34:33 +01001813
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001814 const TensorShape output_shape =
1815 arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
1816 const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
John Richardson73d4aef2018-05-08 14:34:33 +01001817 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
1818 }
1819
1820 return Status{};
1821}
Georgios Pinitasd9769582017-08-03 10:19:40 +01001822} // namespace
1823
1824NEReductionOperationKernel::NEReductionOperationKernel()
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001825 : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001826{
1827}
1828
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001829void NEReductionOperationKernel::configure(const ITensor *input,
1830 ITensor *output,
1831 unsigned int axis,
1832 ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001833{
1834 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001835
John Richardson73d4aef2018-05-08 14:34:33 +01001836 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001837
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001838 _input = input;
1839 _output = output;
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001840 _op = op;
1841 _reduction_axis = axis;
Georgios Pinitasd9769582017-08-03 10:19:40 +01001842
1843 // Configure kernel window
Georgios Pinitas412b7892020-11-11 21:05:24 +00001844 Window win = calculate_max_window(*input->info(), Steps());
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001845 INEKernel::configure(win);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001846
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001847 // Calculate output shape and set if empty
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001848 const TensorShape output_shape =
1849 arm_compute::misc::shape_calculator::compute_reduced_shape(input->info()->tensor_shape(), axis);
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001850 // Output auto initialization if not yet initialized
1851 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
1852 DataType output_data_type = is_arg_min_max ? DataType::S32 : input->info()->data_type();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001853 auto_init_if_empty(*output->info(), input->info()
1854 ->clone()
1855 ->set_tensor_shape(output_shape)
1856 .set_data_type(output_data_type)
1857 .reset_padding()
1858 .set_is_resizable(true));
John Richardson73d4aef2018-05-08 14:34:33 +01001859}
1860
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001861Status NEReductionOperationKernel::validate(const ITensorInfo *input,
1862 const ITensorInfo *output,
1863 unsigned int axis,
1864 ReductionOperation op)
John Richardson73d4aef2018-05-08 14:34:33 +01001865{
1866 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
John Richardson73d4aef2018-05-08 14:34:33 +01001867
1868 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +01001869}
1870
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001871void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001872{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001873 ARM_COMPUTE_UNUSED(info);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001874 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1875 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1876
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001877 reduce_op(window, _input, _output, _reduction_axis, _op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001878}
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001879} // namespace arm_compute