blob: b77219cd79f9b8dd19ecacc7c8cbd1da11c56bc0 [file] [log] [blame]
/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/core/NEON/kernels/NEReductionOperationKernel.h"
25
26#include "arm_compute/core/Coordinates.h"
27#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/IAccessWindow.h"
29#include "arm_compute/core/ITensor.h"
30#include "arm_compute/core/NEON/INEKernel.h"
31#include "arm_compute/core/NEON/NEMath.h"
John Richardson73d4aef2018-05-08 14:34:33 +010032#include "arm_compute/core/TensorInfo.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010033#include "arm_compute/core/Validate.h"
34
Michalis Spyroubcf8a962018-10-12 10:51:31 +010035#include "arm_compute/core/NEON/wrapper/wrapper.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010036#include <arm_neon.h>
37
Michalis Spyroubcf8a962018-10-12 10:51:31 +010038namespace arm_compute
39{
Georgios Pinitasd9769582017-08-03 10:19:40 +010040namespace
41{
42template <class F>
43class Reducer
44{
45public:
46 static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f)
47 {
48 // Set out window
49 Window out_window(window);
50 out_window.set(Window::DimX, Window::Dimension(0, 0, 0));
51
52 // Get first input and output slices
53 Window in_slice = window.first_slice_window_1D();
54 Window out_slice = out_window.first_slice_window_1D();
55
56 do
57 {
58 Iterator in(input, in_slice);
59 Iterator out(output, out_slice);
60
Michalis Spyroubcf8a962018-10-12 10:51:31 +010061 f(in, out, in_slice, out_slice, *input->info());
Georgios Pinitasd9769582017-08-03 10:19:40 +010062 }
Michalis Spyroubcf8a962018-10-12 10:51:31 +010063 while(window.slide_window_slice_1D(in_slice) && out_window.slide_window_slice_1D(out_slice));
64 }
65 static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f)
66 {
67 // Set in window
68 Window in_window(window);
69
70 in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
71
72 // Get first input and output slices
73 Window in_slice = in_window.first_slice_window_2D();
74 Window out_slice = window.first_slice_window_2D();
75
76 do
77 {
78 Iterator in(input, in_slice);
79 Iterator out(output, out_slice);
80
81 f(in, out, in_slice, out_slice, *input->info(), 1);
82 }
83 while(in_window.slide_window_slice_2D(in_slice) && window.slide_window_slice_2D(out_slice));
84 }
85 static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f)
86 {
87 // Set in window
88 Window in_window(window);
89
90 in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
91
92 // Get first input and output slices
93 Window in_slice = in_window.first_slice_window_3D();
94 Window out_slice = window.first_slice_window_3D();
95
96 do
97 {
98 Iterator in(input, in_slice);
99 Iterator out(output, out_slice);
100
101 f(in, out, in_slice, out_slice, *input->info(), 2);
102 }
103 while(in_window.slide_window_slice_3D(in_slice) && window.slide_window_slice_3D(out_slice));
104 }
105 static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f)
106 {
107 // Set in/out window
108 Window in_window(window);
109 Window out_window(window);
110
111 in_window.set(3, Window::Dimension(0, 1, 1));
112 out_window.set(3, Window::Dimension(0, 1, 1));
113
114 // Get first input and output slices
115 Window in_slice = in_window.first_slice_window_4D();
116 Window out_slice = out_window.first_slice_window_4D();
117
118 do
119 {
120 Iterator in(input, in_slice);
121 Iterator out(output, out_slice);
122
123 f(in, out, in_slice, out_slice, *input->info(), 3);
124 }
125 while(in_window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_4D(out_slice));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100126 }
127};
128
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100129template <typename T, int S, ReductionOperation op>
130struct RedOpX
Georgios Pinitasd9769582017-08-03 10:19:40 +0100131{
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100132 /** NEON vector tag type. */
133 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
134
135 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100136 {
137 ARM_COMPUTE_UNUSED(out_slice);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100138 auto vec_sum_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
Georgios Pinitasd9769582017-08-03 10:19:40 +0100139
140 execute_window_loop(in_slice, [&](const Coordinates & id)
141 {
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100142 const auto in_ptr = reinterpret_cast<const T *>(input.ptr());
143 const auto vec_elements = wrapper::vloadq(in_ptr);
144
145 if(op == ReductionOperation::SUM_SQUARE)
146 {
147 vec_sum_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_sum_value);
148 }
149 else
150 {
151 vec_sum_value = wrapper::vadd(vec_elements, vec_sum_value);
152 }
Georgios Pinitasd9769582017-08-03 10:19:40 +0100153 },
154 input);
155
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100156 auto carry_addition = wrapper::vpadd(wrapper::vgethigh(vec_sum_value), wrapper::vgetlow(vec_sum_value));
157 carry_addition = wrapper::vpadd(carry_addition, carry_addition);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100158
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100159 auto res = wrapper::vgetlane(carry_addition, 0);
160 if(op == ReductionOperation::MEAN_SUM)
161 {
162 res /= in_info.dimension(0);
163 }
164
165 *(reinterpret_cast<T *>(output.ptr())) = res;
166 }
167};
168
169template <ReductionOperation op>
170struct RedOpX_qasymm8
171{
172 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info)
173 {
174 ARM_COMPUTE_UNUSED(out_slice);
175 auto vec_sum_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
176 auto vec_sum_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
177 auto vec_sum_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
178 auto vec_sum_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));
179
180 execute_window_loop(in_slice, [&](const Coordinates & id)
181 {
182 const auto vec_elements = wrapper::vloadq(input.ptr());
183
184 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
185 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
186
187 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
188 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
189 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
190 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
191
192 vec_sum_value1 = wrapper::vadd(temp32x4t_1, vec_sum_value1);
193 vec_sum_value2 = wrapper::vadd(temp32x4t_2, vec_sum_value2);
194 vec_sum_value3 = wrapper::vadd(temp32x4t_3, vec_sum_value3);
195 vec_sum_value4 = wrapper::vadd(temp32x4t_4, vec_sum_value4);
196 },
197 input);
198
199 auto carry_addition = wrapper::vadd(vec_sum_value1, vec_sum_value2);
200 carry_addition = wrapper::vadd(carry_addition, vec_sum_value3);
201 carry_addition = wrapper::vadd(carry_addition, vec_sum_value4);
202
203 auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_addition), wrapper::vgetlow(carry_addition));
204 carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
205 auto res = wrapper::vgetlane(carry_paddition, 0);
206
207 if(op == ReductionOperation::MEAN_SUM)
208 {
209 res /= in_info.dimension(0);
210 }
211
212 *(output.ptr()) = static_cast<uint8_t>(res);
213 }
214};
215
216template <typename T, int S, ReductionOperation op>
217struct RedOpYZW
218{
219 /** NEON vector tag type. */
220 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
221
222 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis)
223 {
224 ARM_COMPUTE_UNUSED(out_slice);
225
226 execute_window_loop(in_slice, [&](const Coordinates & id)
227 {
228 auto vec_sum_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
229 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
230 {
231 T *in_ptr;
232 switch(axis)
233 {
234 case 1:
235 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim)));
236 break;
237 case 2:
238 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim)));
239 break;
240 case 3:
241 in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim)));
242 break;
243 default:
244 ARM_COMPUTE_ERROR("Not supported");
245 }
246 const auto vec_elements = wrapper::vloadq(in_ptr);
247
248 if(op == ReductionOperation::SUM_SQUARE)
249 {
250 vec_sum_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_sum_value);
251 }
252 else
253 {
254 vec_sum_value = wrapper::vadd(vec_elements, vec_sum_value);
255 }
256 }
257
258 if(op == ReductionOperation::MEAN_SUM)
259 {
260 auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
261 vec_sum_value = wrapper::vmul(vec_sum_value, vec_width_inv);
262 }
263
264 wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_sum_value);
265 },
266 input, output);
267 }
268};
269
270template <ReductionOperation op>
271struct RedOpYZW_qasymm8
272{
273 inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis)
274 {
275 ARM_COMPUTE_UNUSED(out_slice);
276
277 execute_window_loop(in_slice, [&](const Coordinates & id)
278 {
279 auto vec_sum_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
280 auto vec_sum_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
281 auto vec_sum_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
282 auto vec_sum_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));
283 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
284 {
285 uint8_t *in_ptr;
286 switch(axis)
287 {
288 case 1:
289 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim));
290 break;
291 case 2:
292 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim));
293 break;
294 case 3:
295 in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim));
296 break;
297 default:
298 ARM_COMPUTE_ERROR("Not supported");
299 }
300 const auto vec_elements = wrapper::vloadq(in_ptr);
301
302 const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
303 const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
304
305 const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
306 const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
307 const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
308 const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
309
310 vec_sum_value1 = wrapper::vadd(temp32x4t_1, vec_sum_value1);
311 vec_sum_value2 = wrapper::vadd(temp32x4t_2, vec_sum_value2);
312 vec_sum_value3 = wrapper::vadd(temp32x4t_3, vec_sum_value3);
313 vec_sum_value4 = wrapper::vadd(temp32x4t_4, vec_sum_value4);
314 }
315
316 if(op == ReductionOperation::MEAN_SUM)
317 {
318 const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
319 const auto vec_sum_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_sum_value1), vec_width_inv);
320 const auto vec_sum_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_sum_value2), vec_width_inv);
321 const auto vec_sum_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_sum_value3), vec_width_inv);
322 const auto vec_sum_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_sum_value4), vec_width_inv);
323
324 vec_sum_value1 = vcvtq_u32_f32(vec_sum_value1_f);
325 vec_sum_value2 = vcvtq_u32_f32(vec_sum_value2_f);
326 vec_sum_value3 = vcvtq_u32_f32(vec_sum_value3_f);
327 vec_sum_value4 = vcvtq_u32_f32(vec_sum_value4_f);
328 }
329
330 const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_sum_value1), wrapper::vqmovn(vec_sum_value2));
331 const auto temp16x8t_2 = vcombine_u16(wrapper::vqmovn(vec_sum_value3), wrapper::vqmovn(vec_sum_value4));
332 auto res = vcombine_u8(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
333 wrapper::vstore(output.ptr(), res);
334 },
335 input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100336 }
337};
338
339void reduce_sumsq(const Window &window, const ITensor *input, ITensor *output, unsigned int axis)
340{
341 switch(axis)
342 {
343 case 0:
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100344 switch(input->info()->data_type())
345 {
346#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
347 case DataType::F16:
348 return Reducer<RedOpX<float16_t, 8, ReductionOperation::SUM_SQUARE>>::reduceX(window, input, output, RedOpX<float16_t, 8, ReductionOperation::SUM_SQUARE>());
349#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
350 case DataType::F32:
351 return Reducer<RedOpX<float, 4, ReductionOperation::SUM_SQUARE>>::reduceX(window, input, output, RedOpX<float, 4, ReductionOperation::SUM_SQUARE>());
352 case DataType::QASYMM8:
353 default:
354 ARM_COMPUTE_ERROR("Not supported");
355 }
356 case 1:
357 switch(input->info()->data_type())
358 {
359#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
360 case DataType::F16:
361 return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>>::reduceY(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>());
362#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
363 case DataType::F32:
364 return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>>::reduceY(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>());
365 case DataType::QASYMM8:
366 default:
367 ARM_COMPUTE_ERROR("Not supported");
368 }
369 case 2:
370 switch(input->info()->data_type())
371 {
372#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
373 case DataType::F16:
374 return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>());
375#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
376 case DataType::F32:
377 return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>>::reduceZ(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>());
378 case DataType::QASYMM8:
379 default:
380 ARM_COMPUTE_ERROR("Not supported");
381 }
382 case 3:
383 switch(input->info()->data_type())
384 {
385#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
386 case DataType::F16:
387 return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>>::reduceW(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>());
388#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
389 case DataType::F32:
390 return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>>::reduceW(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>());
391 case DataType::QASYMM8:
392 default:
393 ARM_COMPUTE_ERROR("Not supported");
394 }
395 default:
396 ARM_COMPUTE_ERROR("Unsupported reduction axis");
397 }
398}
399
400void reduce_sum(const Window &window, const ITensor *input, ITensor *output, unsigned int axis)
401{
402 switch(axis)
403 {
404 case 0:
405 switch(input->info()->data_type())
406 {
407 case DataType::QASYMM8:
408 return Reducer<RedOpX_qasymm8<ReductionOperation::SUM>>::reduceX(window, input, output, RedOpX_qasymm8<ReductionOperation::SUM>());
409#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
410 case DataType::F16:
411 return Reducer<RedOpX<float16_t, 8, ReductionOperation::SUM>>::reduceX(window, input, output, RedOpX<float16_t, 8, ReductionOperation::SUM>());
412#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
413 case DataType::F32:
414 return Reducer<RedOpX<float, 4, ReductionOperation::SUM>>::reduceX(window, input, output, RedOpX<float, 4, ReductionOperation::SUM>());
415 default:
416 ARM_COMPUTE_ERROR("Not supported");
417 }
418 case 1:
419 switch(input->info()->data_type())
420 {
421 case DataType::QASYMM8:
422 return Reducer<RedOpYZW_qasymm8<ReductionOperation::SUM>>::reduceY(window, input, output, RedOpYZW_qasymm8<ReductionOperation::SUM>());
423#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
424 case DataType::F16:
425 return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM>>::reduceY(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM>());
426#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
427 case DataType::F32:
428 return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM>>::reduceY(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM>());
429 default:
430 ARM_COMPUTE_ERROR("Not supported");
431 }
432 case 2:
433 switch(input->info()->data_type())
434 {
435 case DataType::QASYMM8:
436 return Reducer<RedOpYZW_qasymm8<ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_qasymm8<ReductionOperation::SUM>());
437#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
438 case DataType::F16:
439 return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM>());
440#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
441 case DataType::F32:
442 return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM>());
443 default:
444 ARM_COMPUTE_ERROR("Not supported");
445 }
446 case 3:
447 switch(input->info()->data_type())
448 {
449 case DataType::QASYMM8:
450 return Reducer<RedOpYZW_qasymm8<ReductionOperation::SUM>>::reduceW(window, input, output, RedOpYZW_qasymm8<ReductionOperation::SUM>());
451#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
452 case DataType::F16:
453 return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM>>::reduceW(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM>());
454#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
455 case DataType::F32:
456 return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM>>::reduceW(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM>());
457 default:
458 ARM_COMPUTE_ERROR("Not supported");
459 }
460 default:
461 ARM_COMPUTE_ERROR("Unsupported reduction axis");
462 }
463}
464void reduce_mean_sum(const Window &window, const ITensor *input, ITensor *output, unsigned int axis)
465{
466 switch(axis)
467 {
468 case 0:
469 switch(input->info()->data_type())
470 {
471 case DataType::QASYMM8:
472 return Reducer<RedOpX_qasymm8<ReductionOperation::MEAN_SUM>>::reduceX(window, input, output, RedOpX_qasymm8<ReductionOperation::MEAN_SUM>());
473#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
474 case DataType::F16:
475 return Reducer<RedOpX<float16_t, 8, ReductionOperation::MEAN_SUM>>::reduceX(window, input, output, RedOpX<float16_t, 8, ReductionOperation::MEAN_SUM>());
476#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
477 case DataType::F32:
478 return Reducer<RedOpX<float, 4, ReductionOperation::MEAN_SUM>>::reduceX(window, input, output, RedOpX<float, 4, ReductionOperation::MEAN_SUM>());
479 default:
480 ARM_COMPUTE_ERROR("Not supported");
481 }
482 case 1:
483 switch(input->info()->data_type())
484 {
485 case DataType::QASYMM8:
486 return Reducer<RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>>::reduceY(window, input, output, RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>());
487#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
488 case DataType::F16:
489 return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>>::reduceY(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>());
490#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
491 case DataType::F32:
492 return Reducer<RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>>::reduceY(window, input, output, RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>());
493 default:
494 ARM_COMPUTE_ERROR("Not supported");
495 }
496 case 2:
497 switch(input->info()->data_type())
498 {
499 case DataType::QASYMM8:
500 return Reducer<RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>>::reduceZ(window, input, output, RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>());
501#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
502 case DataType::F16:
503 return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>());
504#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
505 case DataType::F32:
506 return Reducer<RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>>::reduceZ(window, input, output, RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>());
507 default:
508 ARM_COMPUTE_ERROR("Not supported");
509 }
510 case 3:
511 switch(input->info()->data_type())
512 {
513 case DataType::QASYMM8:
514 return Reducer<RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>>::reduceW(window, input, output, RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>());
515#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
516 case DataType::F16:
517 return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>>::reduceW(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>());
518#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
519 case DataType::F32:
520 return Reducer<RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>>::reduceW(window, input, output, RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>());
521 default:
522 ARM_COMPUTE_ERROR("Not supported");
523 }
Georgios Pinitasd9769582017-08-03 10:19:40 +0100524 default:
525 ARM_COMPUTE_ERROR("Unsupported reduction axis");
526 }
527}
John Richardson73d4aef2018-05-08 14:34:33 +0100528
529TensorShape calculate_output_shape(const TensorShape &input_shape, unsigned int axis)
530{
531 TensorShape output_shape{ input_shape };
532 output_shape.set(axis, 1);
533
534 return output_shape;
535}
536
537Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
538{
539 ARM_COMPUTE_UNUSED(op);
540
541 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100542 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
John Richardson73d4aef2018-05-08 14:34:33 +0100543
544 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100545 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
John Richardson73d4aef2018-05-08 14:34:33 +0100546
547 if(output->total_size() != 0)
548 {
549 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100550 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
John Richardson73d4aef2018-05-08 14:34:33 +0100551
552 const TensorShape output_shape = calculate_output_shape(input->tensor_shape(), axis);
553 const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
554 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
555 }
556
557 return Status{};
558}
559
560std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis)
561{
562 // Calculate output shape and set if empty
563 const TensorShape output_shape = calculate_output_shape(input->tensor_shape(), axis);
564
565 // Output auto initialization if not yet initialized
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100566 auto_init_if_empty(*output, output_shape, 1, input->data_type());
John Richardson73d4aef2018-05-08 14:34:33 +0100567
568 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
569
570 // Configure kernel window
571 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
572 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
573 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
574
575 bool window_changed = update_window_and_padding(win, input_access, output_access);
576 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
577
578 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
579
580 return std::make_tuple(err, win);
581}
Georgios Pinitasd9769582017-08-03 10:19:40 +0100582} // namespace
583
584NEReductionOperationKernel::NEReductionOperationKernel()
585 : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE), _border_size()
586{
587}
588
589BorderSize NEReductionOperationKernel::border_size() const
590{
591 return _border_size;
592}
593
594void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
595{
596 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100597
John Richardson73d4aef2018-05-08 14:34:33 +0100598 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100599
600 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
601
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100602 _input = input;
603 _output = output;
604 _border_size = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
605 _op = op;
606 _reduction_axis = axis;
Georgios Pinitasd9769582017-08-03 10:19:40 +0100607
608 // Configure kernel window
John Richardson73d4aef2018-05-08 14:34:33 +0100609 auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100610
John Richardson73d4aef2018-05-08 14:34:33 +0100611 ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100612
John Richardson73d4aef2018-05-08 14:34:33 +0100613 INEKernel::configure(std::get<1>(win_config));
614}
615
616Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
617{
618 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
619 ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis)));
620
621 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +0100622}
623
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100624void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100625{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100626 ARM_COMPUTE_UNUSED(info);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100627 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
628 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
629
630 switch(_op)
631 {
632 case ReductionOperation::SUM_SQUARE:
633 reduce_sumsq(window, _input, _output, _reduction_axis);
634 break;
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100635 case ReductionOperation::MEAN_SUM:
636 reduce_mean_sum(window, _input, _output, _reduction_axis);
637 break;
638 case ReductionOperation::SUM:
639 reduce_sum(window, _input, _output, _reduction_axis);
640 break;
Georgios Pinitasd9769582017-08-03 10:19:40 +0100641 default:
642 ARM_COMPUTE_ERROR("Unsupported reduction operation.");
643 }
644}
Michalis Spyroubcf8a962018-10-12 10:51:31 +0100645} // namespace arm_compute