blob: 01534f77b46e056d85f6f5972b8537767aadc8c0 [file] [log] [blame]
Georgios Pinitasd9769582017-08-03 10:19:40 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2017-2020 Arm Limited.
Georgios Pinitasd9769582017-08-03 10:19:40 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEReductionOperationKernel.h"
25
Georgios Pinitas8f5802f2019-02-22 11:08:32 +000026#include "arm_compute/core/CPP/Validate.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010027#include "arm_compute/core/Coordinates.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/IAccessWindow.h"
30#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/INEKernel.h"
John Richardson73d4aef2018-05-08 14:34:33 +010032#include "arm_compute/core/TensorInfo.h"
Luca Foschianiee939fb2020-01-28 10:38:07 +000033#include "arm_compute/core/Utils.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010034#include "arm_compute/core/Validate.h"
Michalis Spyrou19bd4122020-01-22 10:27:06 +000035#include "arm_compute/core/utils/misc/SaturateCast.h"
Michalis Spyrouaea14c62019-01-03 11:10:25 +000036#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Georgios Pinitasddb93bb2020-10-02 16:38:59 +010037#include "src/core/NEON/NEMath.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010038
Georgios Pinitasddb93bb2020-10-02 16:38:59 +010039#include "src/core/NEON/wrapper/wrapper.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010040#include <arm_neon.h>
41
Michalis Spyroubcf8a962018-10-12 10:51:31 +010042namespace arm_compute
43{
Georgios Pinitasd9769582017-08-03 10:19:40 +010044namespace
45{
Luca Foschianiee939fb2020-01-28 10:38:07 +000046// Helper function that calls vqmovun/vqmvn, vcombine and vstore, allows templating of RedOpYZW_quantized
47template <typename T>
Sheri Zhang4d91dc62020-09-23 11:22:50 +010048void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0)
Luca Foschianiee939fb2020-01-28 10:38:07 +000049{
50 if(std::is_same<T, uint8_t>::value)
51 {
52 auto res = wrapper::vcombine(wrapper::vqmovun(t1), wrapper::vqmovun(t2));
Sheri Zhang4d91dc62020-09-23 11:22:50 +010053 wrapper::vstore(output.ptr() + offset, res);
Luca Foschianiee939fb2020-01-28 10:38:07 +000054 }
55 else
56 {
57 auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2));
Sheri Zhang4d91dc62020-09-23 11:22:50 +010058 wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), res);
Luca Foschianiee939fb2020-01-28 10:38:07 +000059 }
60}
61
Michalis Spyroub9626ab2019-05-13 17:41:01 +010062template <typename T>
63uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000064{
65 uint32x4_t mask{ 0 };
66 if(op == ReductionOperation::ARG_IDX_MIN)
67 {
68 mask = wrapper::vcgt(b, a);
69 }
70 else
71 {
72 mask = wrapper::vclt(b, a);
73 }
74
75 uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
76 if(axis != 0)
77 {
78 vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
79 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +000080 uint32x4x4_t res = { { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000081
82 return res;
83}
84
Luca Foschianiee939fb2020-01-28 10:38:07 +000085template <typename T>
86uint32x4x4_t calculate_index_quantized(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
Michalis Spyrouaea14c62019-01-03 11:10:25 +000087{
Georgios Pinitasd57891a2019-02-19 18:10:03 +000088 uint32x4x4_t mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +000089 uint8x16_t mask_u8{ 0 };
90 if(op == ReductionOperation::ARG_IDX_MIN)
91 {
92 mask_u8 = wrapper::vcgt(b, a);
93 }
94 else
95 {
96 mask_u8 = wrapper::vclt(b, a);
97 }
Michalis Spyrou254a48a2019-01-14 17:27:39 +000098 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
99 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
100 mask.val[0] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
101 mask.val[1] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
102 mask.val[2] = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
103 mask.val[3] = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
104
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000105 uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
106 { idx + 4, idx + 5, idx + 6, idx + 7 },
107 { idx + 8, idx + 9, idx + 10, idx + 11 },
108 { idx + 12, idx + 13, idx + 14, idx + 15 }
109 }
110 };
111 if(axis != 0)
112 {
113 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
114 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
115 vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
116 vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
117 }
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000118 uint32x4x4_t res =
119 {
120 {
121 vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
122 vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
123 vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
124 vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
125 }
126 };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000127
128 return res;
129}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100130
131// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000132template <typename T>
133inline typename std::enable_if < std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
134 typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type >::type
135 calculate_min(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100136{
137 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
138 return wrapper::vpmin(pmin, pmin);
139}
140
Luca Foschianiee939fb2020-01-28 10:38:07 +0000141// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
142template <typename T>
143inline typename std::enable_if < std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
144 typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type >::type
145 calculate_min(T in)
146{
147 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
148 pmin = wrapper::vpmin(pmin, pmin);
149 pmin = wrapper::vpmin(pmin, pmin);
150 return wrapper::vpmin(pmin, pmin);
151}
152
Usama Arifa4a08ad2019-05-20 12:38:33 +0100153// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000154template <typename T>
155inline typename std::enable_if < std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
156 typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type >::type
157 calculate_max(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100158{
159 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
160 return wrapper::vpmax(pmax, pmax);
161}
Usama Arifa4a08ad2019-05-20 12:38:33 +0100162
163// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
Luca Foschianiee939fb2020-01-28 10:38:07 +0000164template <typename T>
165inline typename std::enable_if < std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
166 typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type >::type
167 calculate_max(T in)
Usama Arifa4a08ad2019-05-20 12:38:33 +0100168{
169 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
Luca Foschianiee939fb2020-01-28 10:38:07 +0000170 pmax = wrapper::vpmax(pmax, pmax);
171 pmax = wrapper::vpmax(pmax, pmax);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100172 return wrapper::vpmax(pmax, pmax);
173}
174
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100175template <typename T>
176uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000177{
178 uint32x4_t res_idx_mask{ 0 };
179 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
180
181 if(op == ReductionOperation::ARG_IDX_MIN)
182 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100183 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000184 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
185 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
186 }
187 else
188 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100189 auto pmax = calculate_max(vec_res_value);
Michalis Spyroub9626ab2019-05-13 17:41:01 +0100190 auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000191 res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
192 }
193
194 res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
195 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
196 pmin = wrapper::vpmin(pmin, pmin);
197 uint32_t res = wrapper::vgetlane(pmin, 0);
198
199 return (res - 0xFFFFFFFF);
200}
201
Luca Foschianiee939fb2020-01-28 10:38:07 +0000202template <typename T>
203uint32_t calculate_vector_index_quantized(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000204{
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000205 uint32x4x4_t res_idx_mask{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000206 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
207 uint8x16_t mask_u8{ 0 };
208 if(op == ReductionOperation::ARG_IDX_MIN)
209 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100210 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000211 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
212 }
213 else
214 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100215 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000216 mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
217 }
218
219 // Widen vectors
220 auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
221 auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
222 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
223 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
224 auto wide_u32_3 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
225 auto wide_u32_4 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
226 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
227 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
228 res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
229 res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
230 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
231 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
232 res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
233 res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);
234
235 uint32_t res = 0xFFFFFFFF;
236 int iter = 0;
237 do
238 {
239 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
240 pmin = wrapper::vpmin(pmin, pmin);
241 res = std::min(wrapper::vgetlane(pmin, 0), res);
242 iter++;
243 }
244 while(iter < 4);
245
246 return (res - 0xFFFFFFFF);
247}
Luca Foschianiee939fb2020-01-28 10:38:07 +0000248
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000249#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitasfad18382019-06-05 15:12:22 +0100250template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000251uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
252{
253 uint32x4x2_t mask{ 0 };
254 uint16x8_t mask_u16{ 0 };
255 if(op == ReductionOperation::ARG_IDX_MIN)
256 {
257 mask_u16 = wrapper::vcgt(b, a);
258 }
259 else
260 {
261 mask_u16 = wrapper::vclt(b, a);
262 }
263 mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
264 mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
265 uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
266 { idx + 4, idx + 5, idx + 6, idx + 7 }
267 }
268 };
269 if(axis != 0)
270 {
271 vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
272 vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
273 }
274 uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
275 wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
276 0, 0
277 };
278
279 return res;
280}
281
Usama Arifa4a08ad2019-05-20 12:38:33 +0100282// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
283inline float16x4_t calculate_min(float16x8_t in)
284{
285 auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
286 pmin = wrapper::vpmin(pmin, pmin);
287 return wrapper::vpmin(pmin, pmin);
288}
289// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
290inline float16x4_t calculate_max(float16x8_t in)
291{
292 auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
293 pmax = wrapper::vpmax(pmax, pmax);
294 return wrapper::vpmax(pmax, pmax);
295}
296
Usama Arif0a5a57a2019-05-23 14:20:33 +0100297template <>
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000298uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
299{
300 uint32x4x2_t res_idx_mask{ 0 };
301 uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
302 uint16x8_t mask_u16;
303 if(op == ReductionOperation::ARG_IDX_MIN)
304 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100305 auto pmin = calculate_min(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000306 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
307 }
308 else
309 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100310 auto pmax = calculate_max(vec_res_value);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000311 mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
312 }
313
314 // Widen vectors
315 auto wide_u32_1 = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
316 auto wide_u32_2 = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
317 res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
318 res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
319 res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
320 res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
321
322 uint32_t res = 0xFFFFFFFF;
323 int iter = 0;
324 do
325 {
326 auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
327 pmin = wrapper::vpmin(pmin, pmin);
328 res = std::min(wrapper::vgetlane(pmin, 0), res);
329 iter++;
330 }
331 while(iter < 2);
332
333 return (res - 0xFFFFFFFF);
334}
335#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
336
Georgios Pinitasd9769582017-08-03 10:19:40 +0100337template <class F>
338class Reducer
339{
340public:
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000341 static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100342 {
343 // Set out window
344 Window out_window(window);
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100345 out_window.set(Window::DimX, Window::Dimension(0, 1, 1));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100346
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100347 f(window, out_window, input, output, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100348 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000349 static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100350 {
351 // Set in window
352 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000353 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100354
355 in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000356 out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100357
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100358 f(in_window, out_window, input, output, 1, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100359 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000360 static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100361 {
362 // Set in window
363 Window in_window(window);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000364 Window out_window(window);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100365
366 in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +0000367 out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100368
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100369 f(in_window, out_window, input, output, 2, op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100370 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000371 static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100372 {
373 // Set in/out window
374 Window in_window(window);
375 Window out_window(window);
376
377 in_window.set(3, Window::Dimension(0, 1, 1));
378 out_window.set(3, Window::Dimension(0, 1, 1));
379
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100380 f(in_window, out_window, input, output, 3, op);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100381 }
382};
383
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000384template <typename T, int S>
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100385struct RedOpX
Georgios Pinitasd9769582017-08-03 10:19:40 +0100386{
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100387 /** NEON vector tag type. */
388 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
389
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100390 inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100391 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100392 const TensorInfo in_info = *(in->info());
393
394 Iterator input(in, in_window);
395 Iterator output(out, out_window);
396 const int window_step_x = 16 / sizeof(T);
397 const auto window_start_x = static_cast<int>(in_window.x().start());
398 const auto window_end_x = static_cast<int>(in_window.x().end());
399
400 execute_window_loop(in_window, [&](const Coordinates &)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000401 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100402 const auto input_ptr = reinterpret_cast<const T *>(input.ptr());
Georgios Pinitasd9769582017-08-03 10:19:40 +0100403
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100404 auto init_res_value = static_cast<T>(0.f);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100405 switch(op)
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000406 {
Usama Arifa4a08ad2019-05-20 12:38:33 +0100407 case ReductionOperation::ARG_IDX_MAX:
408 case ReductionOperation::ARG_IDX_MIN:
409 case ReductionOperation::MIN:
Usama Arif28f0dd92019-05-20 13:44:34 +0100410 case ReductionOperation::MAX:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100411 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100412 init_res_value = static_cast<T>(*input_ptr);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100413 break;
414 }
415 case ReductionOperation::PROD:
416 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100417 init_res_value = static_cast<T>(1.f);
Usama Arifa4a08ad2019-05-20 12:38:33 +0100418 break;
419 }
420 default:
Usama Arifa4a08ad2019-05-20 12:38:33 +0100421 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000422 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100423 auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
Georgios Pinitasd57891a2019-02-19 18:10:03 +0000424 uint32x4x4_t vec_res_idx{ { 0 } };
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000425
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100426 // Compute window_step_x elements per iteration
427 int x = window_start_x;
428 for(; x <= (window_end_x - window_step_x); x += window_step_x)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100429 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100430 const auto vec_elements = wrapper::vloadq(input_ptr + x);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000431 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100432 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000433 case ReductionOperation::SUM_SQUARE:
434 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
435 break;
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100436 case ReductionOperation::MEAN_SUM:
437 case ReductionOperation::SUM:
438 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
439 break;
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000440 case ReductionOperation::PROD:
441 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
442 break;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000443 case ReductionOperation::ARG_IDX_MIN:
444 {
445 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100446 vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000447 vec_res_value = temp_vec_res_value;
448 break;
449 }
450 case ReductionOperation::ARG_IDX_MAX:
451 {
452 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100453 vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000454 vec_res_value = temp_vec_res_value;
455 break;
456 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100457 case ReductionOperation::MIN:
458 {
459 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
460 break;
461 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100462 case ReductionOperation::MAX:
463 {
464 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
465 break;
466 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000467 default:
468 ARM_COMPUTE_ERROR("Not supported");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100469 }
470 }
471
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100472 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100473 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100474 case ReductionOperation::SUM:
475 case ReductionOperation::MEAN_SUM:
476 case ReductionOperation::SUM_SQUARE:
Michele Di Giorgio81d7e782019-08-16 18:03:35 +0100477 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100478 auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
479 for(int i = 0; i < S / 4; ++i)
480 {
481 carry_res = wrapper::vpadd(carry_res, carry_res);
482 }
483 auto res = wrapper::vgetlane(carry_res, 0);
484
485 if(op == ReductionOperation::SUM_SQUARE)
486 {
487 // Compute left-over elements
488 for(; x < window_end_x; ++x)
489 {
490 res += (*(input_ptr + x)) * (*(input_ptr + x));
491 }
492 }
493 else
494 {
495 // Compute left-over elements
496 for(; x < window_end_x; ++x)
497 {
498 res += *(input_ptr + x);
499 }
500 }
501
502 if(op == ReductionOperation::MEAN_SUM)
503 {
504 res /= in_info.dimension(0);
505 }
506
507 *(reinterpret_cast<T *>(output.ptr())) = res;
508 break;
Michele Di Giorgio81d7e782019-08-16 18:03:35 +0100509 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100510 case ReductionOperation::PROD:
giuros01154bc1c2019-03-26 17:44:40 +0000511 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100512 auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
513 T res = 1;
514 for(int i = 0; i < S / 2; ++i)
515 {
516 res *= wrapper::vgetlane(carry_res, i);
517 }
giuros01154bc1c2019-03-26 17:44:40 +0000518
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100519 // Compute left-over elements
520 for(; x < window_end_x; ++x)
521 {
522 res *= *(input_ptr + x);
523 }
524
525 *(reinterpret_cast<T *>(output.ptr())) = res;
526 break;
527 }
528 case ReductionOperation::ARG_IDX_MIN:
giuros01154bc1c2019-03-26 17:44:40 +0000529 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100530 auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
531 auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
532
533 // Compute left-over elements
534 for(; x < window_end_x; ++x)
535 {
536 if(*(input_ptr + x) < res)
537 {
538 idx = x;
539 res = *(input_ptr + x);
540 }
541 }
542 *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
543 break;
giuros01154bc1c2019-03-26 17:44:40 +0000544 }
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100545 case ReductionOperation::ARG_IDX_MAX:
546 {
547 auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
548 auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
549
550 // Compute left-over elements
551 for(; x < window_end_x; ++x)
552 {
553 if(*(input_ptr + x) > res)
554 {
555 idx = x;
556 res = *(input_ptr + x);
557 }
558 }
559 *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
560 break;
561 }
562 case ReductionOperation::MIN:
563 {
564 auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
565
566 // Compute left-over elements
567 for(; x < window_end_x; ++x)
568 {
569 res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
570 }
571 *(reinterpret_cast<T *>(output.ptr())) = res;
572 break;
573 }
574 case ReductionOperation::MAX:
575 {
576 auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
577
578 // Compute left-over elements
579 for(; x < window_end_x; ++x)
580 {
581 res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
582 }
583 *(reinterpret_cast<T *>(output.ptr())) = res;
584 break;
585 }
586 default:
587 ARM_COMPUTE_ERROR("Not supported");
giuros01154bc1c2019-03-26 17:44:40 +0000588 }
giuros01154bc1c2019-03-26 17:44:40 +0000589 },
590 input, output);
591 }
592};
593
Luca Foschianiee939fb2020-01-28 10:38:07 +0000594template <typename T>
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100595struct RedOpX_quantized
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100596{
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100597 inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100598 {
Luca Foschianiee939fb2020-01-28 10:38:07 +0000599 using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
600
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100601 const TensorInfo in_info = *(in->info());
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100602 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
603
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100604 Iterator input(in, in_window);
605 Iterator output(out, out_window);
606 const int window_step_x = 16 / sizeof(T);
607 const auto window_start_x = static_cast<int>(in_window.x().start());
608 const auto window_end_x = static_cast<int>(in_window.x().end());
609
610 execute_window_loop(in_window, [&](const Coordinates &)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100611 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100612 const auto input_ptr = reinterpret_cast<T *>(input.ptr());
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000613
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100614 auto vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
615 auto vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
616 auto vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
617 auto vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000618
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100619 auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
620 auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
621 auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
622 auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000623
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100624 typename wrapper::traits::neon_vector<T, 16>::type vec_res_value = { 0 };
625
626 if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100627 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100628 vec_res_value = wrapper::vdup_n(*input_ptr, wrapper::traits::vector_128_tag{});
629 }
630
631 uint32x4x4_t vec_res_idx{ { 0 } };
632 // Compute window_step_x elements per iteration
633 int x = window_start_x;
634 for(; x <= (window_end_x - window_step_x); x += window_step_x)
635 {
636 const auto vec_elements = wrapper::vloadq(input_ptr + x);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000637 switch(op)
638 {
639 case ReductionOperation::SUM:
640 case ReductionOperation::MEAN_SUM:
641 {
642 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
643 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100644
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000645 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
646 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
647 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
648 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100649
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000650 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
651 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
652 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
653 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
654 break;
655 }
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000656 case ReductionOperation::PROD:
657 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100658 const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
659 const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000660
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000661 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
662 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000663
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000664 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
665 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
666 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
667 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000668
Luca Foschianiee939fb2020-01-28 10:38:07 +0000669 auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
670 auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
671 auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
672 auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000673
674 //de-quantize vec_elements
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100675 temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
676 temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
677 temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
678 temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000679
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100680 vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
681 vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
682 vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
683 vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000684 break;
685 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000686 case ReductionOperation::ARG_IDX_MIN:
687 {
688 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100689 vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000690 vec_res_value = temp_vec_res_value;
691 break;
692 }
693 case ReductionOperation::ARG_IDX_MAX:
694 {
695 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100696 vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000697 vec_res_value = temp_vec_res_value;
698 break;
699 }
Usama Arifa4a08ad2019-05-20 12:38:33 +0100700 case ReductionOperation::MIN:
701 {
702 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
703 break;
704 }
Usama Arif28f0dd92019-05-20 13:44:34 +0100705 case ReductionOperation::MAX:
706 {
707 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
708 break;
709 }
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000710 default:
711 ARM_COMPUTE_ERROR("Not supported");
712 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100713 }
714
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100715 switch(op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100716 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100717 case ReductionOperation::ARG_IDX_MIN:
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000718 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100719 auto idx = calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
720 auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000721
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100722 // Compute left-over elements
723 for(; x < window_end_x; ++x)
724 {
725 if(*(input_ptr + x) < res)
726 {
727 idx = x;
728 res = *(input_ptr + x);
729 }
730 }
731 *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
732 break;
733 }
734 case ReductionOperation::ARG_IDX_MAX:
735 {
736 auto idx = calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
737 auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000738
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100739 // Compute left-over elements
740 for(; x < window_end_x; ++x)
741 {
742 if(*(input_ptr + x) > res)
743 {
744 idx = x;
745 res = *(input_ptr + x);
746 }
747 }
748 *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
749 break;
750 }
751 case ReductionOperation::MIN:
752 {
753 auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000754
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100755 // Compute left-over elements
756 for(; x < window_end_x; ++x)
757 {
758 res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
759 }
760 *(reinterpret_cast<T *>(output.ptr())) = res;
761 break;
762 }
763 case ReductionOperation::MAX:
764 {
765 auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
Luca Foschianiee939fb2020-01-28 10:38:07 +0000766
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100767 // Compute left-over elements
768 for(; x < window_end_x; ++x)
769 {
770 res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
771 }
772 *(reinterpret_cast<T *>(output.ptr())) = res;
773 break;
774 }
775 case ReductionOperation::PROD:
776 {
777 auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
778 carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
779 carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
780
781 float res = wrapper::vgetlane(carry_res, 0);
782 res *= wrapper::vgetlane(carry_res, 1);
783 res *= wrapper::vgetlane(carry_res, 2);
784 res *= wrapper::vgetlane(carry_res, 3);
785
786 // Compute left-over elements
787 for(; x < window_end_x; ++x)
788 {
789 //de-quantize input
790 if(std::is_same<T, uint8_t>::value)
791 {
792 res *= dequantize_qasymm8(*(input_ptr + x), iq_info);
793 }
794 else
795 {
796 res *= dequantize_qasymm8_signed(*(input_ptr + x), iq_info);
797 }
798 }
799
800 //re-quantize result
801 if(std::is_same<T, uint8_t>::value)
802 {
803 res = quantize_qasymm8(res, iq_info);
804 }
805 else
806 {
807 res = quantize_qasymm8_signed(res, iq_info);
808 }
809
810 *reinterpret_cast<T *>(output.ptr()) = static_cast<T>(res);
811 break;
812 }
813 case ReductionOperation::SUM:
814 case ReductionOperation::MEAN_SUM:
815 {
816 auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
817 carry_res = wrapper::vadd(carry_res, vec_res_value3);
818 carry_res = wrapper::vadd(carry_res, vec_res_value4);
819
820 auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
821 carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
822 auto res = static_cast<int32_t>(wrapper::vgetlane(carry_paddition, 0));
823
824 // Compute left-over elements
825 for(; x < window_end_x; ++x)
826 {
827 res += *(input_ptr + x);
828 }
829
830 if(op == ReductionOperation::MEAN_SUM)
831 {
832 res /= static_cast<int32_t>(in_info.dimension(0));
833 }
834 else
835 {
836 // Subtract accumulated offsets
837 res -= (in_info.dimension(0) - 1) * iq_info.offset;
838 }
839 *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(res);
840 break;
841 }
842 default:
843 ARM_COMPUTE_ERROR("Not supported");
844 }
845 },
846 input, output);
847 }
848};
849
850template <typename T, int S>
851struct RedOpYZW
852{
853 /** NEON vector tag type. */
854 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
855 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
856
857 inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int axis, const ReductionOperation op)
858 {
859 const TensorInfo in_info = *(in->info());
860
861 Iterator input(in, in_window);
862 Iterator output(out, out_window);
863 const int window_step_x = 16 / sizeof(T);
864 const auto window_start_x = static_cast<int>(in_window.x().start());
865 const auto window_end_x = static_cast<int>(in_window.x().end());
866
867 execute_window_loop(in_window, [&](const Coordinates &)
868 {
869 const auto input_ptr = reinterpret_cast<T *>(input.ptr());
870
871 // Compute window_step_x elements per iteration
872 int x = window_start_x;
873 for(; x <= (window_end_x - window_step_x); x += window_step_x)
874 {
875 neon_vector vec_res_value = { 0 };
876 switch(op)
877 {
878 case ReductionOperation::ARG_IDX_MAX:
879 case ReductionOperation::ARG_IDX_MIN:
880 case ReductionOperation::MIN:
881 case ReductionOperation::MAX:
882 {
883 vec_res_value = wrapper::vloadq(input_ptr + x);
884 break;
885 }
886 case ReductionOperation::PROD:
887 {
888 vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
889 break;
890 }
891 default:
892 {
893 vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
894 break;
895 }
896 }
897 uint32x4x4_t vec_res_idx{ { 0 } };
898
899 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
900 {
901 const T *in_ptr = reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);
902 const auto vec_elements = wrapper::vloadq(in_ptr);
903 switch(op)
904 {
905 case ReductionOperation::SUM:
906 case ReductionOperation::MEAN_SUM:
907 vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
908 break;
909 case ReductionOperation::SUM_SQUARE:
910 vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
911 break;
912 case ReductionOperation::PROD:
913 vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
914 break;
915 case ReductionOperation::ARG_IDX_MIN:
916 {
917 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
918 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
919 vec_res_value = temp_vec_res_value;
920 break;
921 }
922 case ReductionOperation::ARG_IDX_MAX:
923 {
924 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
925 vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
926 vec_res_value = temp_vec_res_value;
927 break;
928 }
929 case ReductionOperation::MIN:
930 {
931 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
932 break;
933 }
934 case ReductionOperation::MAX:
935 {
936 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
937 break;
938 }
939 default:
940 ARM_COMPUTE_ERROR("Not supported");
941 }
942 }
943
944 if(op == ReductionOperation::MEAN_SUM)
945 {
946 auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
947 vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
948 }
949
950 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
951 {
952 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]);
953#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
954 if(std::is_same<T, float16_t>::value)
955 {
956 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, vec_res_idx.val[1]);
957 }
958#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000959 }
960 else
961 {
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100962 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value);
Michalis Spyrou19bd4122020-01-22 10:27:06 +0000963 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100964 }
965
Sheri Zhang4d91dc62020-09-23 11:22:50 +0100966 // Compute left-over elements
967 for(; x < window_end_x; ++x)
968 {
969 auto res_value = 0.f;
970 switch(op)
971 {
972 case ReductionOperation::ARG_IDX_MAX:
973 case ReductionOperation::ARG_IDX_MIN:
974 case ReductionOperation::MIN:
975 case ReductionOperation::MAX:
976 {
977 res_value = *(input_ptr + x);
978 break;
979 }
980 case ReductionOperation::PROD:
981 {
982 res_value = static_cast<T>(1.f);
983 break;
984 }
985 default:
986 {
987 res_value = static_cast<T>(0.f);
988 break;
989 }
990 }
991
992 uint32_t res_idx = 0;
993 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
994 {
995 const T *in_ptr = reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);
996
997 switch(op)
998 {
999 case ReductionOperation::SUM:
1000 case ReductionOperation::MEAN_SUM:
1001 res_value += *in_ptr;
1002 break;
1003 case ReductionOperation::SUM_SQUARE:
1004 res_value += *in_ptr * *in_ptr;
1005 break;
1006 case ReductionOperation::PROD:
1007 res_value *= *in_ptr;
1008 break;
1009 case ReductionOperation::ARG_IDX_MIN:
1010 {
1011 if(*in_ptr < res_value)
1012 {
1013 res_value = *in_ptr;
1014 res_idx = dim;
1015 }
1016 break;
1017 }
1018 case ReductionOperation::ARG_IDX_MAX:
1019 {
1020 if(*in_ptr > res_value)
1021 {
1022 res_value = *in_ptr;
1023 res_idx = dim;
1024 }
1025 break;
1026 }
1027 case ReductionOperation::MIN:
1028 {
1029 res_value = *in_ptr < res_value ? *in_ptr : res_value;
1030 break;
1031 }
1032 case ReductionOperation::MAX:
1033 {
1034 res_value = *in_ptr > res_value ? *in_ptr : res_value;
1035 break;
1036 }
1037 default:
1038 ARM_COMPUTE_ERROR("Not supported");
1039 }
1040 }
1041
1042 if(op == ReductionOperation::MEAN_SUM)
1043 {
1044 res_value /= in_info.dimension(axis);
1045 }
1046
1047 if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
1048 {
1049 *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx;
1050 }
1051 else
1052 {
1053 *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value;
1054 }
1055 }
1056 },
1057 input, output);
1058 }
1059};
1060
1061template <typename T, int S, int axis, ReductionOperation op>
1062struct RedOpYZW_complex
1063{
1064 /** NEON vector tag type. */
1065 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
1066 using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
1067
1068 inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation)
1069 {
1070 ARM_COMPUTE_ERROR_ON(axis != 2);
1071
1072 const TensorInfo in_info = *(in->info());
1073
1074 Iterator input(in, in_window);
1075 Iterator output(out, out_window);
1076 const int window_step_x = 16 / sizeof(T);
1077 const auto window_start_x = static_cast<int>(in_window.x().start());
1078 const auto window_end_x = static_cast<int>(in_window.x().end());
1079
1080 const size_t stride_z = in_info.strides_in_bytes()[axis];
1081
1082 execute_window_loop(in_window, [&](const Coordinates &)
1083 {
1084 // Compute window_step_x elements per iteration
1085 int x = window_start_x;
1086 for(; x <= (window_end_x - window_step_x); x += window_step_x)
1087 {
1088 neon_vector vec_res_value_0 = { 0 };
1089 neon_vector vec_res_value_1 = { 0 };
1090
1091 vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
1092 vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
1093
1094 T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
1095 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1096 {
1097 T *in_ptr_0;
1098 T *in_ptr_1;
1099 switch(axis)
1100 {
1101 case 2:
1102 in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
1103 in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim);
1104 break;
1105 default:
1106 ARM_COMPUTE_ERROR("Not supported");
1107 }
1108 const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
1109 const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);
1110
1111 switch(op)
1112 {
1113 case ReductionOperation::SUM:
1114 vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
1115 vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
1116 break;
1117 default:
1118 ARM_COMPUTE_ERROR("Not supported");
1119 }
1120 }
1121
1122 wrapper::vstore(out_ptr, vec_res_value_0);
1123 wrapper::vstore(out_ptr + 4, vec_res_value_1);
1124 }
1125
1126 // Compute left-over elements
1127 for(; x < window_end_x; ++x)
1128 {
1129 auto res_value_0 = 0.f;
1130 auto res_value_1 = 0.f;
1131
1132 T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
1133 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1134 {
1135 T *in_ptr;
1136 switch(axis)
1137 {
1138 case 2:
1139 in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
1140 break;
1141 default:
1142 ARM_COMPUTE_ERROR("Not supported");
1143 }
1144 switch(op)
1145 {
1146 case ReductionOperation::SUM:
1147 res_value_0 += *in_ptr;
1148 res_value_1 += *(in_ptr + 1);
1149 break;
1150 default:
1151 ARM_COMPUTE_ERROR("Not supported");
1152 }
1153 }
1154 *out_ptr = res_value_0;
1155 *(out_ptr + 1) = res_value_1;
1156 }
1157 },
1158 input, output);
1159 }
1160};
1161
1162template <typename T>
1163struct RedOpYZW_quantized
1164{
1165 inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int axis, const ReductionOperation op)
1166 {
1167 const TensorInfo in_info = *(in->info());
1168
1169 Iterator input(in, in_window);
1170 Iterator output(out, out_window);
1171 const int window_step_x = 16 / sizeof(T);
1172 const auto window_start_x = static_cast<int>(in_window.x().start());
1173 const auto window_end_x = static_cast<int>(in_window.x().end());
1174
1175 using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
1176
1177 const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
1178
1179 execute_window_loop(in_window, [&](const Coordinates &)
1180 {
1181 const auto input_ptr = reinterpret_cast<T *>(input.ptr());
1182
1183 // Compute window_step_x elements per iteration
1184 int x = window_start_x;
1185 for(; x <= (window_end_x - window_step_x); x += window_step_x)
1186 {
1187 uint32x4x4_t vec_res_idx{ { 0 } };
1188 auto vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1189 auto vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1190 auto vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1191 auto vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1192
1193 auto vec_res_value1_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1194 auto vec_res_value2_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1195 auto vec_res_value3_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1196 auto vec_res_value4_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1197
1198 auto vec_res_value = wrapper::vloadq(input_ptr + x);
1199
1200 for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
1201 {
1202 const T *in_ptr = input_ptr + x + in_info.strides_in_bytes()[axis] * index_dim;
1203 const auto vec_elements = wrapper::vloadq(in_ptr);
1204 switch(op)
1205 {
1206 case ReductionOperation::SUM:
1207 case ReductionOperation::MEAN_SUM:
1208 {
1209 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
1210 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
1211
1212 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
1213 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
1214 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
1215 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
1216
1217 vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
1218 vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
1219 vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
1220 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
1221 break;
1222 }
1223 case ReductionOperation::PROD:
1224 {
1225 const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
1226 const auto scale32x4f_4 = wrapper::vdup_n(iq_info.scale, wrapper::traits::vector_128_tag{});
1227
1228 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
1229 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
1230
1231 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
1232 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
1233 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
1234 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
1235
1236 auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
1237 auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
1238 auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
1239 auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
1240
1241 //de-quantize vec_elements
1242 temp32x4f_1 = wrapper::vmul(wrapper::vsub(temp32x4f_1, offset32x4f_4), scale32x4f_4);
1243 temp32x4f_2 = wrapper::vmul(wrapper::vsub(temp32x4f_2, offset32x4f_4), scale32x4f_4);
1244 temp32x4f_3 = wrapper::vmul(wrapper::vsub(temp32x4f_3, offset32x4f_4), scale32x4f_4);
1245 temp32x4f_4 = wrapper::vmul(wrapper::vsub(temp32x4f_4, offset32x4f_4), scale32x4f_4);
1246
1247 vec_res_value1_f = wrapper::vmul(temp32x4f_1, vec_res_value1_f);
1248 vec_res_value2_f = wrapper::vmul(temp32x4f_2, vec_res_value2_f);
1249 vec_res_value3_f = wrapper::vmul(temp32x4f_3, vec_res_value3_f);
1250 vec_res_value4_f = wrapper::vmul(temp32x4f_4, vec_res_value4_f);
1251 break;
1252 }
1253 case ReductionOperation::ARG_IDX_MIN:
1254 {
1255 auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1256 vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
1257 vec_res_value = temp_vec_res_value;
1258 break;
1259 }
1260 case ReductionOperation::ARG_IDX_MAX:
1261 {
1262 auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1263 vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
1264 vec_res_value = temp_vec_res_value;
1265 break;
1266 }
1267 case ReductionOperation::MIN:
1268 {
1269 vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1270 break;
1271 }
1272 case ReductionOperation::MAX:
1273 {
1274 vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1275 break;
1276 }
1277 default:
1278 ARM_COMPUTE_ERROR("Not supported");
1279 }
1280 }
1281
1282 switch(op)
1283 {
1284 case ReductionOperation::ARG_IDX_MIN:
1285 case ReductionOperation::ARG_IDX_MAX:
1286 {
1287 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x), vec_res_idx.val[0]);
1288 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 4, vec_res_idx.val[1]);
1289 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 8, vec_res_idx.val[2]);
1290 wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 12, vec_res_idx.val[3]);
1291 break;
1292 }
1293 case ReductionOperation::MIN:
1294 case ReductionOperation::MAX:
1295 {
1296 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), vec_res_value);
1297 break;
1298 }
1299 case ReductionOperation::SUM:
1300 {
1301 // Subtract offsets
1302 auto offsets = vdupq_n_s32((in_info.dimension(axis) - 1) * iq_info.offset);
1303
1304 auto vec_res_s_value1 = wrapper::vreinterpret(vec_res_value1);
1305 auto vec_res_s_value2 = wrapper::vreinterpret(vec_res_value2);
1306 auto vec_res_s_value3 = wrapper::vreinterpret(vec_res_value3);
1307 auto vec_res_s_value4 = wrapper::vreinterpret(vec_res_value4);
1308
1309 vec_res_s_value1 = wrapper::vsub(vec_res_s_value1, offsets);
1310 vec_res_s_value2 = wrapper::vsub(vec_res_s_value2, offsets);
1311 vec_res_s_value3 = wrapper::vsub(vec_res_s_value3, offsets);
1312 vec_res_s_value4 = wrapper::vsub(vec_res_s_value4, offsets);
1313
1314 const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_s_value1), wrapper::vqmovn(vec_res_s_value2));
1315 const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_s_value3), wrapper::vqmovn(vec_res_s_value4));
1316
1317 combine_and_store<T>(temp16x8t_1, temp16x8t_2, output, x);
1318 break;
1319 }
1320 case ReductionOperation::MEAN_SUM:
1321 {
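                         // Mean: scale the 32-bit sums by 1 / N in float, then convert back and narrow to the quantized output type with saturation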
1322 const auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<float>(in_info.dimension(axis)), wrapper::traits::vector_128_tag{}));
1323 vec_res_value1_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value1), vec_width_inv);
1324 vec_res_value2_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value2), vec_width_inv);
1325 vec_res_value3_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value3), vec_width_inv);
1326 vec_res_value4_f = wrapper::vmul(wrapper::vcvt<float>(vec_res_value4), vec_width_inv);
1327
1328 vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f);
1329 vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f);
1330 vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f);
1331 vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f);
1332
1333 const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1334 const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1335 auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
1336
1337 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res);
1338 break;
1339 }
1340 case ReductionOperation::PROD:
1341 {
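                         // The running product is held in the de-quantized (float) domain; re-quantize with q = value / scale + offset before narrowing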
1342 const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
1343 const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale));
1344
1345 //re-quantize
1346 vec_res_value1_f = wrapper::vadd(wrapper::vmul(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
1347 vec_res_value2_f = wrapper::vadd(wrapper::vmul(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
1348 vec_res_value3_f = wrapper::vadd(wrapper::vmul(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
1349 vec_res_value4_f = wrapper::vadd(wrapper::vmul(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
1350
1351 vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f);
1352 vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f);
1353 vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f);
1354 vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f);
1355
1356 const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1357 const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1358 auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
1359
1360 wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res);
1361 break;
1362 }
1363 default:
1364 ARM_COMPUTE_ERROR("Not supported");
1365 }
1366 }
1367
1368 // Compute left-over elements
1369 for(; x < window_end_x; ++x)
1370 {
1371                        float res_value = 0;
1372 switch(op)
1373 {
1374 case ReductionOperation::ARG_IDX_MAX:
1375 case ReductionOperation::ARG_IDX_MIN:
1376 case ReductionOperation::MIN:
1377 case ReductionOperation::MAX:
1378 {
1379 res_value = *(input_ptr + x);
1380 break;
1381 }
1382 case ReductionOperation::PROD:
1383 {
1384 res_value = static_cast<T>(1.0f);
1385 break;
1386 }
1387 default:
1388 {
1389 res_value = static_cast<T>(0.0f);
1390 break;
1391 }
1392 }
1393 uint32_t res_idx = 0;
1394
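                             // Walk along the reduction axis, combining one element per iteration according to the selected operation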
1395 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1396 {
1397 const T *in_ptr = reinterpret_cast<T *>(input.ptr() + x + in_info.strides_in_bytes()[axis] * dim);
1398 switch(op)
1399 {
1400 case ReductionOperation::SUM:
1401 case ReductionOperation::MEAN_SUM:
1402 {
1403 res_value += *in_ptr;
1404 break;
1405 }
1406 case ReductionOperation::SUM_SQUARE:
1407 {
1408 res_value += *in_ptr * *in_ptr;
1409 break;
1410 }
1411 case ReductionOperation::PROD:
1412 {
1413 //de-quantize input
1414 if(std::is_same<T, uint8_t>::value)
1415 {
1416                                res_value *= dequantize_qasymm8(*in_ptr, iq_info);
1417 }
1418 else
1419 {
1420                                res_value *= dequantize_qasymm8_signed(*in_ptr, iq_info);
1421 }
1422 break;
1423 }
1424 case ReductionOperation::ARG_IDX_MIN:
1425 {
1426 if(*in_ptr < res_value)
1427 {
1428 res_value = *in_ptr;
1429 res_idx = dim;
1430 }
1431 break;
1432 }
1433 case ReductionOperation::ARG_IDX_MAX:
1434 {
1435 if(*in_ptr > res_value)
1436 {
1437 res_value = *in_ptr;
1438 res_idx = dim;
1439 }
1440 break;
1441 }
1442 case ReductionOperation::MIN:
1443 {
1444 res_value = *in_ptr < res_value ? *in_ptr : res_value;
1445 break;
1446 }
1447 case ReductionOperation::MAX:
1448 {
1449 res_value = *in_ptr > res_value ? *in_ptr : res_value;
1450 break;
1451 }
1452 default:
1453 ARM_COMPUTE_ERROR("Not supported");
1454 }
1455 }
1456
1457 switch(op)
1458 {
1459 case ReductionOperation::MEAN_SUM:
1460 {
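                             // Mean: divide the accumulated sum by the reduction length and saturate to the output type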
1461 res_value /= in_info.dimension(axis);
1462 *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res_value);
1463 break;
1464 }
1465 case ReductionOperation::SUM:
1466 {
1467 // Subtract accumulated offsets
1468 res_value -= (in_info.dimension(axis) - 1) * iq_info.offset;
1469 *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res_value);
1470 break;
1471 }
1472 case ReductionOperation::PROD:
1473 {
1474 //re-quantize result
1475 if(std::is_same<T, uint8_t>::value)
1476 {
1477 res_value = quantize_qasymm8(res_value, iq_info);
1478 }
1479 else
1480 {
1481 res_value = quantize_qasymm8_signed(res_value, iq_info);
1482 }
1483                    *(reinterpret_cast<T *>(output.ptr() + x)) = res_value;
1484                    break;
1485 }
1486 case ReductionOperation::ARG_IDX_MIN:
1487 case ReductionOperation::ARG_IDX_MAX:
1488 {
1489 *(reinterpret_cast<uint32_t *>(output.ptr() + x * 4)) = res_idx;
1490 break;
1491 }
1492 default:
1493 *(reinterpret_cast<T *>(output.ptr() + x)) = res_value;
1494 }
1495 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001496 },
1497 input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001498 }
1499};
1500
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001501void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001502{
giuros01154bc1c2019-03-26 17:44:40 +00001503 const bool is_complex = (input->info()->num_channels() == 2);
1504
1505 if(is_complex)
1506 {
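         // Complex (two-channel) input is only supported for SUM reduction along the Z axis on F32 data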
1507 switch(axis)
1508 {
1509 case 2:
1510 switch(input->info()->data_type())
1511 {
1512 case DataType::F32:
1513 switch(op)
1514 {
1515 case ReductionOperation::SUM:
1516 return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op);
1517 default:
1518 ARM_COMPUTE_ERROR("Not supported");
1519 }
1520 default:
1521 ARM_COMPUTE_ERROR("Not supported");
1522 }
1523 default:
1524 ARM_COMPUTE_ERROR("Not supported");
1525 }
1526 }
1527
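     // Real (single-channel) input: dispatch on the reduction axis and the input data type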
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001528 switch(axis)
1529 {
1530 case 0:
1531 switch(input->info()->data_type())
1532 {
1533 case DataType::QASYMM8:
Luca Foschianiee939fb2020-01-28 10:38:07 +00001534 return Reducer<RedOpX_quantized<uint8_t>>::reduceX(window, input, output, RedOpX_quantized<uint8_t>(), op);
1535 case DataType::QASYMM8_SIGNED:
1536 return Reducer<RedOpX_quantized<int8_t>>::reduceX(window, input, output, RedOpX_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001537#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1538 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001539 return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001540#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1541 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001542 return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001543 case DataType::S32:
1544 return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001545 default:
1546 ARM_COMPUTE_ERROR("Not supported");
1547 }
1548 case 1:
1549 switch(input->info()->data_type())
1550 {
1551 case DataType::QASYMM8:
Luca Foschianiee939fb2020-01-28 10:38:07 +00001552 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceY(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
1553 case DataType::QASYMM8_SIGNED:
1554 return Reducer<RedOpYZW_quantized<int8_t>>::reduceY(window, input, output, RedOpYZW_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001555#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1556 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001557 return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001558#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1559 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001560 return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001561 case DataType::S32:
1562 return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001563 default:
1564 ARM_COMPUTE_ERROR("Not supported");
1565 }
1566 case 2:
1567 switch(input->info()->data_type())
1568 {
1569 case DataType::QASYMM8:
Luca Foschianiee939fb2020-01-28 10:38:07 +00001570 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
1571 case DataType::QASYMM8_SIGNED:
1572 return Reducer<RedOpYZW_quantized<int8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001573#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1574 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001575 return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001576#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1577 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001578 return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001579 case DataType::S32:
1580 return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001581 default:
1582 ARM_COMPUTE_ERROR("Not supported");
1583 }
1584 case 3:
1585 switch(input->info()->data_type())
1586 {
1587 case DataType::QASYMM8:
Luca Foschianiee939fb2020-01-28 10:38:07 +00001588 return Reducer<RedOpYZW_quantized<uint8_t>>::reduceW(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
1589 case DataType::QASYMM8_SIGNED:
1590 return Reducer<RedOpYZW_quantized<int8_t>>::reduceW(window, input, output, RedOpYZW_quantized<int8_t>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001591#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1592 case DataType::F16:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001593 return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001594#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1595 case DataType::F32:
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001596 return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
Michalis Spyroub9626ab2019-05-13 17:41:01 +01001597 case DataType::S32:
1598 return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001599 default:
1600 ARM_COMPUTE_ERROR("Not supported");
1601 }
1602 default:
1603 ARM_COMPUTE_ERROR("Unsupported reduction axis");
1604 }
1605}
John Richardson73d4aef2018-05-08 14:34:33 +01001606
1607Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1608{
1609 ARM_COMPUTE_UNUSED(op);
1610
1611 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
Georgios Pinitas8f5802f2019-02-22 11:08:32 +00001612 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
giuros01154bc1c2019-03-26 17:44:40 +00001613
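     // Single-channel (real) input supports the full set of data types below; two-channel (complex) input is restricted to F32 SUM along axis 2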
1614 if(input->num_channels() == 1)
1615 {
Luca Foschianiee939fb2020-01-28 10:38:07 +00001616 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
giuros01154bc1c2019-03-26 17:44:40 +00001617 }
1618 else
1619 {
1620 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
1621 ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM);
1622 ARM_COMPUTE_RETURN_ERROR_ON(axis != 2);
1623 }
John Richardson73d4aef2018-05-08 14:34:33 +01001624
1625 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001626 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
John Richardson73d4aef2018-05-08 14:34:33 +01001627
1628 if(output->total_size() != 0)
1629 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001630 bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
1631 if(!is_arg_min_max)
1632 {
1633 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Isabella Gottardi0a1090a2019-02-14 18:07:36 +00001634 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
giuros01154bc1c2019-03-26 17:44:40 +00001635 ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels());
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001636 }
1637 else
1638 {
Michele Di Giorgio9637b2e2019-09-23 16:49:49 +01001639 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32, DataType::S32);
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001640 }
John Richardson73d4aef2018-05-08 14:34:33 +01001641
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001642 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
John Richardson73d4aef2018-05-08 14:34:33 +01001643 const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
1644 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
1645 }
1646
1647 return Status{};
1648}
Georgios Pinitasd9769582017-08-03 10:19:40 +01001649} // namespace
1650
1651NEReductionOperationKernel::NEReductionOperationKernel()
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001652 : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001653{
1654}
1655
Georgios Pinitasd9769582017-08-03 10:19:40 +01001656void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
1657{
1658 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001659
John Richardson73d4aef2018-05-08 14:34:33 +01001660 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
Georgios Pinitasd9769582017-08-03 10:19:40 +01001661
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001662 _input = input;
1663 _output = output;
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001664 _op = op;
1665 _reduction_axis = axis;
Georgios Pinitasd9769582017-08-03 10:19:40 +01001666
1667 // Configure kernel window
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001668 Coordinates coord;
1669 coord.set_num_dimensions(input->info()->num_dimensions());
1670 input->info()->set_valid_region(ValidRegion(coord, input->info()->tensor_shape()));
1671 Window win = calculate_max_window(*input->info(), Steps(input->info()->dimension(0)));
1672 INEKernel::configure(win);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001673
Sheri Zhang4d91dc62020-09-23 11:22:50 +01001674 // Calculate output shape and set if empty
1675 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->info()->tensor_shape(), axis);
1676 // Output auto initialization if not yet initialized
1677 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
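     // ARG_IDX_* reductions produce indices, so the output data type is forced to S32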
1678 DataType output_data_type = is_arg_min_max ? DataType::S32 : input->info()->data_type();
1679 auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
1680 output->info()->set_valid_region(ValidRegion(coord, output_shape));
John Richardson73d4aef2018-05-08 14:34:33 +01001681}
1682
1683Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1684{
1685 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
John Richardson73d4aef2018-05-08 14:34:33 +01001686
1687 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +01001688}
1689
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001690void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
Georgios Pinitasd9769582017-08-03 10:19:40 +01001691{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001692 ARM_COMPUTE_UNUSED(info);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001693 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1694 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1695
Michalis Spyrouaea14c62019-01-03 11:10:25 +00001696 reduce_op(window, _input, _output, _reduction_axis, _op);
Georgios Pinitasd9769582017-08-03 10:19:40 +01001697}
Michalis Spyroubcf8a962018-10-12 10:51:31 +01001698} // namespace arm_compute