/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/core/NEON/kernels/NEReductionOperationKernel.h"

#include "arm_compute/core/Coordinates.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "src/core/CPP/Validate.h"
#include "src/core/NEON/INEKernel.h"
#include "src/core/NEON/NEMath.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
#include "support/SaturateCast.h"

#include "src/core/NEON/wrapper/wrapper.h"
#include <arm_neon.h>

namespace arm_compute
{
namespace
{
// Helper function that calls vqmovun/vqmovn, vcombine and vstore; allows templating of RedOpYZW_quantized
template <typename T>
void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0)
{
    if(std::is_same<T, uint8_t>::value)
    {
        auto res = wrapper::vcombine(wrapper::vqmovun(t1), wrapper::vqmovun(t2));
        wrapper::vstore(output.ptr() + offset, res);
    }
    else
    {
        auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2));
        wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), res);
    }
}

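// calculate_index: compares the updated running value `a` against the previous running value `b`
// and selects, per lane, either the new element index (starting at `idx`) or the index previously
// stored in `c`. Only val[0] is used for non-quantized types; when reducing along an axis other
// than 0 every lane receives the same index.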
template <typename T>
uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4_t mask{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask = wrapper::vcgt(b, a);
    }
    else
    {
        mask = wrapper::vclt(b, a);
    }

    uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
    if(axis != 0)
    {
        vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    uint32x4x4_t res = { { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 } };

    return res;
}

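// Quantized variant of calculate_index: the comparison yields an 8-bit lane mask, which is widened
// (8 -> 16 -> 32 bit) so that one 32-bit index per input byte can be selected across the four
// uint32x4_t index vectors.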
template <typename T>
uint32x4x4_t calculate_index_quantized(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4x4_t mask{ { 0 } };
    uint8x16_t mask_u8{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask_u8 = wrapper::vcgt(b, a);
    }
    else
    {
        mask_u8 = wrapper::vclt(b, a);
    }
    auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
    auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
    mask.val[0] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
    mask.val[1] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
    mask.val[2] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
    mask.val[3] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));

    uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
                               { idx + 4, idx + 5, idx + 6, idx + 7 },
                               { idx + 8, idx + 9, idx + 10, idx + 11 },
                               { idx + 12, idx + 13, idx + 14, idx + 15 }
                             }
                           };
    if(axis != 0)
    {
        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    uint32x4x4_t res =
    {
        {
            vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
            vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
            vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
            vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
        }
    };

    return res;
}

// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
template <typename T>
inline typename std::enable_if < std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
       typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type >::type
       calculate_min(T in)
{
    auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
    return wrapper::vpmin(pmin, pmin);
}

// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
template <typename T>
inline typename std::enable_if < std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
       typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type >::type
       calculate_min(T in)
{
    auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
    pmin = wrapper::vpmin(pmin, pmin);
    pmin = wrapper::vpmin(pmin, pmin);
    return wrapper::vpmin(pmin, pmin);
}

// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
template <typename T>
inline typename std::enable_if < std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
       typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type >::type
       calculate_max(T in)
{
    auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
    return wrapper::vpmax(pmax, pmax);
}

// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
template <typename T>
inline typename std::enable_if < std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
       typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type >::type
       calculate_max(T in)
{
    auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
    pmax = wrapper::vpmax(pmax, pmax);
    pmax = wrapper::vpmax(pmax, pmax);
    return wrapper::vpmax(pmax, pmax);
}

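// calculate_vector_index: collapses the per-lane index vector into the single index of the winning
// (min/max) lane. Non-winning lanes are zeroed and then biased by 0xFFFFFFFF so they lose the
// pairwise-min selection; the final subtraction of 0xFFFFFFFF removes the same bias from the
// winning lane.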
template <typename T>
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
{
    uint32x4_t res_idx_mask{ 0 };
    uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);

    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        auto pmin    = calculate_min(vec_res_value);
        auto mask    = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
        res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
    }
    else
    {
        auto pmax    = calculate_max(vec_res_value);
        auto mask    = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
        res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
    }

    res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
    auto pmin    = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
    pmin         = wrapper::vpmin(pmin, pmin);
    uint32_t res = wrapper::vgetlane(pmin, 0);

    return (res - 0xFFFFFFFF);
}

template <typename T>
uint32_t calculate_vector_index_quantized(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
{
    uint32x4x4_t res_idx_mask{ { 0 } };
    uint32x4_t   mask_ones = vdupq_n_u32(0xFFFFFFFF);
    uint8x16_t   mask_u8{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        auto pmin = calculate_min(vec_res_value);
        mask_u8   = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
    }
    else
    {
        auto pmax = calculate_max(vec_res_value);
        mask_u8   = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
    }

    // Widen vectors
    auto wide_u16_1     = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
    auto wide_u16_2     = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
    auto wide_u32_1     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
    auto wide_u32_2     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
    auto wide_u32_3     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
    auto wide_u32_4     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
    res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
    res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
    res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
    res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);

    uint32_t res  = 0xFFFFFFFF;
    int      iter = 0;
    do
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
        pmin      = wrapper::vpmin(pmin, pmin);
        res       = std::min(wrapper::vgetlane(pmin, 0), res);
        iter++;
    }
    while(iter < 4);

    return (res - 0xFFFFFFFF);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4x2_t mask{ 0 };
    uint16x8_t   mask_u16{ 0 };
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask_u16 = wrapper::vcgt(b, a);
    }
    else
    {
        mask_u16 = wrapper::vclt(b, a);
    }
    mask.val[0]          = wrapper::vmovl(wrapper::vgetlow(mask_u16));
    mask.val[1]          = wrapper::vmovl(wrapper::vgethigh(mask_u16));
    uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
                               { idx + 4, idx + 5, idx + 6, idx + 7 }
                             }
                           };
    if(axis != 0)
    {
        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
                         wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
                         0, 0
                       };

    return res;
}

// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
inline float16x4_t calculate_min(float16x8_t in)
{
    auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
    pmin      = wrapper::vpmin(pmin, pmin);
    return wrapper::vpmin(pmin, pmin);
}
// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
inline float16x4_t calculate_max(float16x8_t in)
{
    auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
    pmax      = wrapper::vpmax(pmax, pmax);
    return wrapper::vpmax(pmax, pmax);
}

template <>
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
{
    uint32x4x2_t res_idx_mask{ 0 };
    uint32x4_t   mask_ones = vdupq_n_u32(0xFFFFFFFF);
    uint16x8_t   mask_u16;
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        auto pmin = calculate_min(vec_res_value);
        mask_u16  = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
    }
    else
    {
        auto pmax = calculate_max(vec_res_value);
        mask_u16  = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
    }

    // Widen vectors
    auto wide_u32_1     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
    auto wide_u32_2     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);

    uint32_t res  = 0xFFFFFFFF;
    int      iter = 0;
    do
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
        pmin      = wrapper::vpmin(pmin, pmin);
        res       = std::min(wrapper::vgetlane(pmin, 0), res);
        iter++;
    }
    while(iter < 2);

    return (res - 0xFFFFFFFF);
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

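// Reducer dispatches a reduction functor F over the requested axis: the window along that axis is
// collapsed to a single step and F is invoked with the input/output windows (plus the axis index
// for the Y/Z/W cases).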
template <class F>
class Reducer
{
public:
    static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set out window
        Window out_window(window);
        out_window.set(Window::DimX, Window::Dimension(0, 1, 1));

        f(window, out_window, input, output, op);
    }
    static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in window
        Window in_window(window);
        Window out_window(window);

        in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
        out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));

        f(in_window, out_window, input, output, 1, op);
    }
    static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in window
        Window in_window(window);
        Window out_window(window);

        in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
        out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));

        f(in_window, out_window, input, output, 2, op);
    }
    static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in/out window
        Window in_window(window);
        Window out_window(window);

        in_window.set(3, Window::Dimension(0, 1, 1));
        out_window.set(3, Window::Dimension(0, 1, 1));

        f(in_window, out_window, input, output, 3, op);
    }
};

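// RedOpX: reduction along the X axis for non-quantized types. The main loop accumulates
// window_step_x (16 / sizeof(T)) elements per iteration; the vector accumulator is then reduced
// horizontally and any left-over tail elements are folded in scalarly before the result is stored.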
template <typename T, int S>
struct RedOpX
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;

    inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
    {
        const TensorInfo in_info = *(in->info());

        Iterator input(in, in_window);
        Iterator output(out, out_window);
        const int  window_step_x  = 16 / sizeof(T);
        const auto window_start_x = static_cast<int>(in_window.x().start());
        const auto window_end_x   = static_cast<int>(in_window.x().end());

        execute_window_loop(in_window, [&](const Coordinates &)
        {
            const auto input_ptr = reinterpret_cast<const T *>(input.ptr());

            auto init_res_value = static_cast<T>(0.f);
            switch(op)
            {
                case ReductionOperation::ARG_IDX_MAX:
                case ReductionOperation::ARG_IDX_MIN:
                case ReductionOperation::MIN:
                case ReductionOperation::MAX:
                {
                    init_res_value = static_cast<T>(*input_ptr);
                    break;
                }
                case ReductionOperation::PROD:
                {
                    init_res_value = static_cast<T>(1.f);
                    break;
                }
                default:
                    break;
            }
            auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
            uint32x4x4_t vec_res_idx{ { 0 } };

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto vec_elements = wrapper::vloadq(input_ptr + x);
                switch(op)
                {
                    case ReductionOperation::SUM_SQUARE:
                        vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                        break;
                    case ReductionOperation::MEAN_SUM:
                    case ReductionOperation::SUM:
                        vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                        break;
                    case ReductionOperation::PROD:
                        vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
                        break;
                    case ReductionOperation::ARG_IDX_MIN:
                    {
                        auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                        vec_res_idx   = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                        vec_res_value = temp_vec_res_value;
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MAX:
                    {
                        auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                        vec_res_idx   = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                        vec_res_value = temp_vec_res_value;
                        break;
                    }
                    case ReductionOperation::MIN:
                    {
                        vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                        break;
                    }
                    case ReductionOperation::MAX:
                    {
                        vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                        break;
                    }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            }

            switch(op)
            {
                case ReductionOperation::SUM:
                case ReductionOperation::MEAN_SUM:
                case ReductionOperation::SUM_SQUARE:
                {
                    auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                    for(int i = 0; i < S / 4; ++i)
                    {
                        carry_res = wrapper::vpadd(carry_res, carry_res);
                    }
                    auto res = wrapper::vgetlane(carry_res, 0);

                    if(op == ReductionOperation::SUM_SQUARE)
                    {
                        // Compute left-over elements
                        for(; x < window_end_x; ++x)
                        {
                            res += (*(input_ptr + x)) * (*(input_ptr + x));
                        }
                    }
                    else
                    {
                        // Compute left-over elements
                        for(; x < window_end_x; ++x)
                        {
                            res += *(input_ptr + x);
                        }
                    }

                    if(op == ReductionOperation::MEAN_SUM)
                    {
                        res /= in_info.dimension(0);
                    }

                    *(reinterpret_cast<T *>(output.ptr())) = res;
                    break;
                }
                case ReductionOperation::PROD:
                {
                    auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                    T    res       = 1;
                    for(int i = 0; i < S / 2; ++i)
                    {
                        res *= wrapper::vgetlane(carry_res, i);
                    }

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        res *= *(input_ptr + x);
                    }

                    *(reinterpret_cast<T *>(output.ptr())) = res;
                    break;
                }
                case ReductionOperation::ARG_IDX_MIN:
                {
                    auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        if(*(input_ptr + x) < res)
                        {
                            idx = x;
                            res = *(input_ptr + x);
                        }
                    }
                    *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
                    break;
                }
                case ReductionOperation::ARG_IDX_MAX:
                {
                    auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        if(*(input_ptr + x) > res)
                        {
                            idx = x;
                            res = *(input_ptr + x);
                        }
                    }
                    *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
                    break;
                }
                case ReductionOperation::MIN:
                {
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
                    }
                    *(reinterpret_cast<T *>(output.ptr())) = res;
                    break;
                }
                case ReductionOperation::MAX:
                {
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
                    }
                    *(reinterpret_cast<T *>(output.ptr())) = res;
                    break;
                }
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        },
        input, output);
    }
};

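// RedOpX_quantized: X-axis reduction for QASYMM8/QASYMM8_SIGNED. SUM-like operations accumulate in
// four widened 32-bit vectors, PROD works on dequantized float vectors and re-quantizes the final
// scalar, and plain SUM compensates for the accumulated quantization offsets at the end.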
template <typename T>
struct RedOpX_quantized
{
    inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
    {
        using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;

        const TensorInfo              in_info = *(in->info());
        const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();

        Iterator input(in, in_window);
        Iterator output(out, out_window);
        const int  window_step_x  = 16 / sizeof(T);
        const auto window_start_x = static_cast<int>(in_window.x().start());
        const auto window_end_x   = static_cast<int>(in_window.x().end());

        execute_window_loop(in_window, [&](const Coordinates &)
        {
            const auto input_ptr = reinterpret_cast<T *>(input.ptr());

            auto vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
            auto vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
            auto vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
            auto vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});

            auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
            auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
            auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
            auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));

            typename wrapper::traits::neon_vector<T, 16>::type vec_res_value = { 0 };

            if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX)
            {
                vec_res_value = wrapper::vdup_n(*input_ptr, wrapper::traits::vector_128_tag{});
            }

            uint32x4x4_t vec_res_idx{ { 0 } };
            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto vec_elements = wrapper::vloadq(input_ptr + x);
                switch(op)
                {
                    case ReductionOperation::SUM:
                    case ReductionOperation::MEAN_SUM:
                    {
                        const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
                        const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));

                        const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
                        const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
                        const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
                        const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));

                        vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
                        vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
                        vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
                        vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
                        break;
                    }
                    case ReductionOperation::PROD:
                    {
                        const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
                        const auto scale32x4f_4  = vdupq_n_f32(iq_info.scale);

                        const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
                        const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));

                        const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
                        const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
                        const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
                        const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));

                        auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
                        auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
                        auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
                        auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);

                        //de-quantize vec_elements
                        temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
                        temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
                        temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
                        temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);

                        vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
                        vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
                        vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
                        vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MIN:
                    {
                        auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                        vec_res_idx   = calculate_index_quantized<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                        vec_res_value = temp_vec_res_value;
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MAX:
                    {
                        auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                        vec_res_idx   = calculate_index_quantized<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                        vec_res_value = temp_vec_res_value;
                        break;
                    }
                    case ReductionOperation::MIN:
                    {
                        vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                        break;
                    }
                    case ReductionOperation::MAX:
                    {
                        vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                        break;
                    }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            }

            switch(op)
            {
                case ReductionOperation::ARG_IDX_MIN:
                {
                    auto idx = calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        if(*(input_ptr + x) < res)
                        {
                            idx = x;
                            res = *(input_ptr + x);
                        }
                    }
                    *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
                    break;
                }
                case ReductionOperation::ARG_IDX_MAX:
                {
                    auto idx = calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        if(*(input_ptr + x) > res)
                        {
                            idx = x;
                            res = *(input_ptr + x);
                        }
                    }
                    *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
                    break;
                }
                case ReductionOperation::MIN:
                {
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
                    }
                    *(reinterpret_cast<T *>(output.ptr())) = res;
                    break;
                }
                case ReductionOperation::MAX:
                {
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
                    }
                    *(reinterpret_cast<T *>(output.ptr())) = res;
                    break;
                }
                case ReductionOperation::PROD:
                {
                    auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
                    carry_res      = wrapper::vmul(carry_res, vec_res_value3_f);
                    carry_res      = wrapper::vmul(carry_res, vec_res_value4_f);

                    float res = wrapper::vgetlane(carry_res, 0);
                    res *= wrapper::vgetlane(carry_res, 1);
                    res *= wrapper::vgetlane(carry_res, 2);
                    res *= wrapper::vgetlane(carry_res, 3);

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        //de-quantize input
                        if(std::is_same<T, uint8_t>::value)
                        {
                            res *= dequantize_qasymm8(*(input_ptr + x), iq_info);
                        }
                        else
                        {
                            res *= dequantize_qasymm8_signed(*(input_ptr + x), iq_info);
                        }
                    }

                    //re-quantize result
                    if(std::is_same<T, uint8_t>::value)
                    {
                        res = quantize_qasymm8(res, iq_info);
                    }
                    else
                    {
                        res = quantize_qasymm8_signed(res, iq_info);
                    }

                    *reinterpret_cast<T *>(output.ptr()) = static_cast<T>(res);
                    break;
                }
                case ReductionOperation::SUM:
                case ReductionOperation::MEAN_SUM:
                {
                    auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
                    carry_res      = wrapper::vadd(carry_res, vec_res_value3);
                    carry_res      = wrapper::vadd(carry_res, vec_res_value4);

                    auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
                    carry_paddition      = wrapper::vpadd(carry_paddition, carry_paddition);
                    auto res             = static_cast<int32_t>(wrapper::vgetlane(carry_paddition, 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        res += *(input_ptr + x);
                    }

                    if(op == ReductionOperation::MEAN_SUM)
                    {
                        res /= static_cast<int32_t>(in_info.dimension(0));
                    }
                    else
                    {
                        // Subtract accumulated offsets
                        res -= (in_info.dimension(0) - 1) * iq_info.offset;
                    }
                    *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(res);
                    break;
                }
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        },
        input, output);
    }
};

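// RedOpYZW: reduction along Y, Z or W for non-quantized types. For each block of X elements a
// vector accumulator walks the reduced axis using that axis' byte stride; left-over X elements are
// processed with a scalar version of the same loop.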
template <typename T, int S>
struct RedOpYZW
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
    using neon_vector  = typename wrapper::traits::neon_vector<T, S>::type;

    inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int axis, const ReductionOperation op)
    {
        const TensorInfo in_info = *(in->info());

        Iterator input(in, in_window);
        Iterator output(out, out_window);
        const int  window_step_x  = 16 / sizeof(T);
        const auto window_start_x = static_cast<int>(in_window.x().start());
        const auto window_end_x   = static_cast<int>(in_window.x().end());

        execute_window_loop(in_window, [&](const Coordinates &)
        {
            const auto input_ptr = reinterpret_cast<T *>(input.ptr());

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                neon_vector vec_res_value = { 0 };
                switch(op)
                {
                    case ReductionOperation::ARG_IDX_MAX:
                    case ReductionOperation::ARG_IDX_MIN:
                    case ReductionOperation::MIN:
                    case ReductionOperation::MAX:
                    {
                        vec_res_value = wrapper::vloadq(input_ptr + x);
                        break;
                    }
                    case ReductionOperation::PROD:
                    {
                        vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
                        break;
                    }
                    default:
                    {
                        vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
                        break;
                    }
                }
                uint32x4x4_t vec_res_idx{ { 0 } };

                for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                {
                    const T   *in_ptr       = reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);
                    const auto vec_elements = wrapper::vloadq(in_ptr);
                    switch(op)
                    {
                        case ReductionOperation::SUM:
                        case ReductionOperation::MEAN_SUM:
                            vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                            break;
                        case ReductionOperation::SUM_SQUARE:
                            vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                            break;
                        case ReductionOperation::PROD:
                            vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
                            break;
                        case ReductionOperation::ARG_IDX_MIN:
                        {
                            auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                            vec_res_idx   = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                            vec_res_value = temp_vec_res_value;
                            break;
                        }
                        case ReductionOperation::ARG_IDX_MAX:
                        {
                            auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                            vec_res_idx   = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                            vec_res_value = temp_vec_res_value;
                            break;
                        }
                        case ReductionOperation::MIN:
                        {
                            vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                            break;
                        }
                        case ReductionOperation::MAX:
                        {
                            vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                            break;
                        }
                        default:
                            ARM_COMPUTE_ERROR("Not supported");
                    }
                }

                if(op == ReductionOperation::MEAN_SUM)
                {
                    auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
                    vec_res_value      = wrapper::vmul(vec_res_value, vec_width_inv);
                }

                if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
                {
                    wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    if(std::is_same<T, float16_t>::value)
                    {
                        wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, vec_res_idx.val[1]);
                    }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                }
                else
                {
                    wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value);
                }
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                auto res_value = 0.f;
                switch(op)
                {
                    case ReductionOperation::ARG_IDX_MAX:
                    case ReductionOperation::ARG_IDX_MIN:
                    case ReductionOperation::MIN:
                    case ReductionOperation::MAX:
                    {
                        res_value = *(input_ptr + x);
                        break;
                    }
                    case ReductionOperation::PROD:
                    {
                        res_value = static_cast<T>(1.f);
                        break;
                    }
                    default:
                    {
                        res_value = static_cast<T>(0.f);
                        break;
                    }
                }

                uint32_t res_idx = 0;
                for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                {
                    const T *in_ptr = reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);

                    switch(op)
                    {
                        case ReductionOperation::SUM:
                        case ReductionOperation::MEAN_SUM:
                            res_value += *in_ptr;
                            break;
                        case ReductionOperation::SUM_SQUARE:
                            res_value += *in_ptr * *in_ptr;
                            break;
                        case ReductionOperation::PROD:
                            res_value *= *in_ptr;
                            break;
                        case ReductionOperation::ARG_IDX_MIN:
                        {
                            if(*in_ptr < res_value)
                            {
                                res_value = *in_ptr;
                                res_idx   = dim;
                            }
                            break;
                        }
                        case ReductionOperation::ARG_IDX_MAX:
                        {
                            if(*in_ptr > res_value)
                            {
                                res_value = *in_ptr;
                                res_idx   = dim;
                            }
                            break;
                        }
                        case ReductionOperation::MIN:
                        {
                            res_value = *in_ptr < res_value ? *in_ptr : res_value;
                            break;
                        }
                        case ReductionOperation::MAX:
                        {
                            res_value = *in_ptr > res_value ? *in_ptr : res_value;
                            break;
                        }
                        default:
                            ARM_COMPUTE_ERROR("Not supported");
                    }
                }

                if(op == ReductionOperation::MEAN_SUM)
                {
                    res_value /= in_info.dimension(axis);
                }

                if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
                {
                    *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx;
                }
                else
                {
                    *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value;
                }
            }
        },
        input, output);
    }
};

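// RedOpYZW_complex: SUM reduction over axis 2 for complex (interleaved real/imaginary) data; each
// block of X positions accumulates two vectors so that both components of every element are summed.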
template <typename T, int S, int axis, ReductionOperation op>
struct RedOpYZW_complex
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
    using neon_vector  = typename wrapper::traits::neon_vector<T, S>::type;

    inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation)
    {
        ARM_COMPUTE_ERROR_ON(axis != 2);

        const TensorInfo in_info = *(in->info());

        Iterator input(in, in_window);
        Iterator output(out, out_window);
        const int  window_step_x  = 16 / sizeof(T);
        const auto window_start_x = static_cast<int>(in_window.x().start());
        const auto window_end_x   = static_cast<int>(in_window.x().end());

        const size_t stride_z = in_info.strides_in_bytes()[axis];

        execute_window_loop(in_window, [&](const Coordinates &)
        {
            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                neon_vector vec_res_value_0 = { 0 };
                neon_vector vec_res_value_1 = { 0 };

                vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
                vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});

                T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
                for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                {
                    T *in_ptr_0;
                    T *in_ptr_1;
                    switch(axis)
                    {
                        case 2:
                            in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
                            in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim);
                            break;
                        default:
                            ARM_COMPUTE_ERROR("Not supported");
                    }
                    const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
                    const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);

                    switch(op)
                    {
                        case ReductionOperation::SUM:
                            vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
                            vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
                            break;
                        default:
                            ARM_COMPUTE_ERROR("Not supported");
                    }
                }

                wrapper::vstore(out_ptr, vec_res_value_0);
                wrapper::vstore(out_ptr + 4, vec_res_value_1);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                auto res_value_0 = 0.f;
                auto res_value_1 = 0.f;

                T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
                for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                {
                    T *in_ptr;
                    switch(axis)
                    {
                        case 2:
                            in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
                            break;
                        default:
                            ARM_COMPUTE_ERROR("Not supported");
                    }
                    switch(op)
                    {
                        case ReductionOperation::SUM:
                            res_value_0 += *in_ptr;
                            res_value_1 += *(in_ptr + 1);
                            break;
                        default:
                            ARM_COMPUTE_ERROR("Not supported");
                    }
                }
                *out_ptr       = res_value_0;
                *(out_ptr + 1) = res_value_1;
            }
        },
        input, output);
    }
};

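// RedOpYZW_quantized: reduction along Y, Z or W for quantized types. As in RedOpX_quantized,
// SUM/MEAN_SUM accumulate in widened 32-bit vectors and PROD in dequantized float vectors; the
// results are narrowed back with saturating moves before being stored.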
1164template <typename T>
1165struct RedOpYZW_quantized
1166{
1167 inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int axis, const ReductionOperation op)
1168 {
1169 const TensorInfo in_info = *(in->info());
1170
1171 Iterator input(in, in_window);
1172 Iterator output(out, out_window);
1173 const int window_step_x = 16 / sizeof(T);
1174 const auto window_start_x = static_cast<int>(in_window.x().start());
1175 const auto window_end_x = static_cast<int>(in_window.x().end());
1176
1177 using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
1178
1179 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
1180
1181 execute_window_loop(in_window, [&](const Coordinates &)
1182 {
1183 const auto input_ptr = reinterpret_cast<T *>(input.ptr());
1184
1185 // Compute window_step_x elements per iteration
1186 int x = window_start_x;
1187 for(; x <= (window_end_x - window_step_x); x += window_step_x)
1188 {
1189 uint32x4x4_t vec_res_idx{ { 0 } };
1190 auto vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1191 auto vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1192 auto vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1193 auto vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1194
1195 auto vec_res_value1_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1196 auto vec_res_value2_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1197 auto vec_res_value3_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1198 auto vec_res_value4_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1199
1200 auto vec_res_value = wrapper::vloadq(input_ptr + x);
1201
1202 for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
1203 {
1204 const T *in_ptr = input_ptr + x + in_info.strides_in_bytes()[axis] * index_dim;
1205 const auto vec_elements = wrapper::vloadq(in_ptr);
1206 switch(op)
1207 {
1208 case ReductionOperation::SUM:
1209 case ReductionOperation::MEAN_SUM:
1210 {
1211 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
1212 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
1213
1214 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
1215 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
1216 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
1217 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
1218
1219 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
1220 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
1221 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
1222 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
1223 break;
1224 }
1225 case ReductionOperation::PROD:
1226 {
1227 const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
1228 const auto scale32x4f_4 = wrapper::vdup_n(iq_info.scale, wrapper::traits::vector_128_tag{});
1229
1230 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
1231 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
1232
1233 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
1234 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
1235 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
1236 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
1237
1238 auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
1239 auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
1240 auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
1241 auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
1242
1243 //de-quantize vec_elements
1244 temp32x4f_1 = wrapper::vmul(wrapper::vsub(temp32x4f_1, offset32x4f_4), scale32x4f_4);
1245 temp32x4f_2 = wrapper::vmul(wrapper::vsub(temp32x4f_2, offset32x4f_4), scale32x4f_4);
1246 temp32x4f_3 = wrapper::vmul(wrapper::vsub(temp32x4f_3, offset32x4f_4), scale32x4f_4);
1247 temp32x4f_4 = wrapper::vmul(wrapper::vsub(temp32x4f_4, offset32x4f_4), scale32x4f_4);
1248
1249 vec_res_value1_f = wrapper::vmul(temp32x4f_1, vec_res_value1_f);
1250 vec_res_value2_f = wrapper::vmul(temp32x4f_2, vec_res_value2_f);
1251 vec_res_value3_f = wrapper::vmul(temp32x4f_3, vec_res_value3_f);
1252 vec_res_value4_f = wrapper::vmul(temp32x4f_4, vec_res_value4_f);
1253 break;
1254 }
1255 case ReductionOperation::ARG_IDX_MIN:
1256 {
1257 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1258 vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
1259 vec_res_value = temp_vec_res_value;
1260 break;
1261 }
1262 case ReductionOperation::ARG_IDX_MAX:
1263 {
1264 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1265 vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
1266 vec_res_value = temp_vec_res_value;
1267 break;
1268 }
1269 case ReductionOperation::MIN:
1270 {
1271 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1272 break;
1273 }
1274 case ReductionOperation::MAX:
1275 {
1276 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1277 break;
1278 }
1279 default:
1280 ARM_COMPUTE_ERROR("Not supported");
1281 }
1282 }
1283
1284 switch(op)
1285 {
1286 case ReductionOperation::ARG_IDX_MIN:
1287 case ReductionOperation::ARG_IDX_MAX:
1288 {
1289 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x), vec_res_idx.val[0]);
1290 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 4, vec_res_idx.val[1]);
1291 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 8, vec_res_idx.val[2]);
1292 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 12, vec_res_idx.val[3]);
1293 break;
1294 }
1295 case ReductionOperation::MIN:
1296 case ReductionOperation::MAX:
1297 {
1298 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), vec_res_value);
1299 break;
1300 }
1301 case ReductionOperation::SUM:
1302 {
                        // Subtract offsets
                        auto offsets = vdupq_n_s32((in_info.dimension(axis) - 1) * iq_info.offset);

                        auto vec_res_s_value1 = wrapper::vreinterpret(vec_res_value1);
                        auto vec_res_s_value2 = wrapper::vreinterpret(vec_res_value2);
                        auto vec_res_s_value3 = wrapper::vreinterpret(vec_res_value3);
                        auto vec_res_s_value4 = wrapper::vreinterpret(vec_res_value4);

                        vec_res_s_value1 = wrapper::vsub(vec_res_s_value1, offsets);
                        vec_res_s_value2 = wrapper::vsub(vec_res_s_value2, offsets);
                        vec_res_s_value3 = wrapper::vsub(vec_res_s_value3, offsets);
                        vec_res_s_value4 = wrapper::vsub(vec_res_s_value4, offsets);

                        const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_s_value1), wrapper::vqmovn(vec_res_s_value2));
                        const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_s_value3), wrapper::vqmovn(vec_res_s_value4));

                        combine_and_store<T>(temp16x8t_1, temp16x8t_2, output, x);
                        break;
                    }
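                    // MEAN_SUM needs no offset correction: dividing the raw sum (which carries
                    // N * offset) by N leaves exactly one offset, which is already the correct
                    // quantized representation of the mean.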
                    case ReductionOperation::MEAN_SUM:
                    {
                        const auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<float>(in_info.dimension(axis)), wrapper::traits::vector_128_tag{}));
                        vec_res_value1_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value1), vec_width_inv);
                        vec_res_value2_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value2), vec_width_inv);
                        vec_res_value3_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value3), vec_width_inv);
                        vec_res_value4_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value4), vec_width_inv);

                        vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f);
                        vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f);
                        vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f);
                        vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f);

                        const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
                        const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
                        auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));

                        wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res);
                        break;
                    }
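                    // PROD was accumulated in de-quantized float (the affine quantization mapping
                    // does not distribute over multiplication), so the result is mapped back to the
                    // quantized domain here as q = real / scale + offset.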
                    case ReductionOperation::PROD:
                    {
                        const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
                        const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale));

                        // Re-quantize
                        vec_res_value1_f = wrapper::vadd(wrapper::vmul(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
                        vec_res_value2_f = wrapper::vadd(wrapper::vmul(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
                        vec_res_value3_f = wrapper::vadd(wrapper::vmul(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
                        vec_res_value4_f = wrapper::vadd(wrapper::vmul(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);

                        vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f);
                        vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f);
                        vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f);
                        vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f);

                        const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
                        const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
                        auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));

                        wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res);
                        break;
                    }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            }

            // Compute left-over elements
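            // These are the tail elements that do not fill a complete vector of window_step_x lanes;
            // they are reduced one at a time in scalar code, accumulating in float and converting
            // back to the output format at the end.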
            for(; x < window_end_x; ++x)
            {
                float res_value = 0.f;
                switch(op)
                {
                    case ReductionOperation::ARG_IDX_MAX:
                    case ReductionOperation::ARG_IDX_MIN:
                    case ReductionOperation::MIN:
                    case ReductionOperation::MAX:
                    {
                        res_value = *(input_ptr + x);
                        break;
                    }
                    case ReductionOperation::PROD:
                    {
                        res_value = static_cast<T>(1.0f);
                        break;
                    }
                    default:
                    {
                        res_value = static_cast<T>(0.0f);
                        break;
                    }
                }
                uint32_t res_idx = 0;

                for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                {
                    const T *in_ptr = reinterpret_cast<T *>(input.ptr() + x + in_info.strides_in_bytes()[axis] * dim);
                    switch(op)
                    {
                        case ReductionOperation::SUM:
                        case ReductionOperation::MEAN_SUM:
                        {
                            res_value += *in_ptr;
                            break;
                        }
                        case ReductionOperation::SUM_SQUARE:
                        {
                            res_value += *in_ptr * *in_ptr;
                            break;
                        }
                        case ReductionOperation::PROD:
                        {
                            // De-quantize input
                            if(std::is_same<T, uint8_t>::value)
                            {
                                res_value *= dequantize_qasymm8(*in_ptr, iq_info);
                            }
                            else
                            {
                                res_value *= dequantize_qasymm8_signed(*in_ptr, iq_info);
                            }
                            break;
                        }
                        case ReductionOperation::ARG_IDX_MIN:
                        {
                            if(*in_ptr < res_value)
                            {
                                res_value = *in_ptr;
                                res_idx = dim;
                            }
                            break;
                        }
                        case ReductionOperation::ARG_IDX_MAX:
                        {
                            if(*in_ptr > res_value)
                            {
                                res_value = *in_ptr;
                                res_idx = dim;
                            }
                            break;
                        }
                        case ReductionOperation::MIN:
                        {
                            res_value = *in_ptr < res_value ? *in_ptr : res_value;
                            break;
                        }
                        case ReductionOperation::MAX:
                        {
                            res_value = *in_ptr > res_value ? *in_ptr : res_value;
                            break;
                        }
                        default:
                            ARM_COMPUTE_ERROR("Not supported");
                    }
                }

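                // Write out the scalar result for this element, applying the same per-operation
                // post-processing as the vectorized path above.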
                switch(op)
                {
                    case ReductionOperation::MEAN_SUM:
                    {
                        int32_t res = static_cast<int32_t>(res_value);
                        res /= static_cast<int32_t>(in_info.dimension(axis));
                        *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res);
                        break;
                    }
                    case ReductionOperation::SUM:
                    {
                        // Subtract accumulated offsets
                        res_value -= (in_info.dimension(axis) - 1) * iq_info.offset;
                        *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res_value);
                        break;
                    }
                    case ReductionOperation::PROD:
                    {
                        // Re-quantize result
                        T res = 0;
                        if(std::is_same<T, uint8_t>::value)
                        {
                            res = quantize_qasymm8(res_value, iq_info);
                        }
                        else
                        {
                            res = quantize_qasymm8_signed(res_value, iq_info);
                        }
                        *(reinterpret_cast<T *>(output.ptr() + x)) = res;
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MIN:
                    case ReductionOperation::ARG_IDX_MAX:
                    {
                        *(reinterpret_cast<uint32_t *>(output.ptr() + x * 4)) = res_idx;
                        break;
                    }
                    default:
                        *(reinterpret_cast<T *>(output.ptr() + x)) = res_value;
                }
            }
        },
        input, output);
    }
};

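// reduce_op dispatches first on the reduction axis and then on the data type: axis 0 uses the
// RedOpX* functors (reduction along the contiguous innermost dimension) via Reducer<>::reduceX,
// while axes 1-3 use the RedOpYZW* functors via Reducer<>::reduceY/Z/W. Complex (2-channel)
// F32 input is only supported for SUM along axis 2.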
void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
{
    const bool is_complex = (input->info()->num_channels() == 2);

    if(is_complex)
    {
        switch(axis)
        {
            case 2:
                switch(input->info()->data_type())
                {
                    case DataType::F32:
                        switch(op)
                        {
                            case ReductionOperation::SUM:
                                return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op);
                            default:
                                ARM_COMPUTE_ERROR("Not supported");
                        }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            default:
                ARM_COMPUTE_ERROR("Not supported");
        }
    }

    switch(axis)
    {
        case 0:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpX_quantized<uint8_t>>::reduceX(window, input, output, RedOpX_quantized<uint8_t>(), op);
                case DataType::QASYMM8_SIGNED:
                    return Reducer<RedOpX_quantized<int8_t>>::reduceX(window, input, output, RedOpX_quantized<int8_t>(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
                case DataType::S32:
                    return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 1:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_quantized<uint8_t>>::reduceY(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
                case DataType::QASYMM8_SIGNED:
                    return Reducer<RedOpYZW_quantized<int8_t>>::reduceY(window, input, output, RedOpYZW_quantized<int8_t>(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
                case DataType::S32:
                    return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 2:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_quantized<uint8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
                case DataType::QASYMM8_SIGNED:
                    return Reducer<RedOpYZW_quantized<int8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<int8_t>(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
                case DataType::S32:
                    return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 3:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_quantized<uint8_t>>::reduceW(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
                case DataType::QASYMM8_SIGNED:
                    return Reducer<RedOpYZW_quantized<int8_t>>::reduceW(window, input, output, RedOpYZW_quantized<int8_t>(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
                case DataType::S32:
                    return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        default:
            ARM_COMPUTE_ERROR("Unsupported reduction axis");
    }
}

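// validate_arguments centralizes the checks shared by configure() and the static validate():
// supported data types, the single-channel vs. complex (2-channel) constraints, the axis range
// and, when the output is already initialized, its type/shape consistency with the reduced input.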
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
    ARM_COMPUTE_UNUSED(op);

    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);

    if(input->num_channels() == 1)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
        ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM);
        ARM_COMPUTE_RETURN_ERROR_ON(axis != 2);
    }

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");

    if(output->total_size() != 0)
    {
        bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
        if(!is_arg_min_max)
        {
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
            ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels());
        }
        else
        {
            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32, DataType::S32);
        }

        const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
        const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
    }

    return Status{};
}
} // namespace

NEReductionOperationKernel::NEReductionOperationKernel()
    : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE)
{
}

void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));

    _input = input;
    _output = output;
    _op = op;
    _reduction_axis = axis;

    // Configure kernel window
    Coordinates coord;
    coord.set_num_dimensions(input->info()->num_dimensions());
    input->info()->set_valid_region(ValidRegion(coord, input->info()->tensor_shape()));
    Window win = calculate_max_window(*input->info(), Steps(input->info()->dimension(0)));
    INEKernel::configure(win);

    // Calculate output shape and set if empty
    const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->info()->tensor_shape(), axis);
    // Output auto initialization if not yet initialized
    const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
    DataType output_data_type = is_arg_min_max ? DataType::S32 : input->info()->data_type();
    auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
    output->info()->set_valid_region(ValidRegion(coord, output_shape));
}
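
// Example usage (a minimal sketch, not part of this file): the kernel is normally created and
// scheduled through the NEReductionOperation runtime function, but a direct invocation would look
// roughly like the following, assuming `input` and `output` are already allocated tensors:
//
//   NEReductionOperationKernel kernel;
//   kernel.configure(&input, &output, /* axis */ 0, ReductionOperation::SUM);
//   NEScheduler::get().schedule(&kernel, Window::DimY);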

Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));

    return Status{};
}

void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    reduce_op(window, _input, _output, _reduction_axis, _op);
}
} // namespace arm_compute