/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEReductionOperationKernel.h"

#include "arm_compute/core/Coordinates.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/INEKernel.h"
#include "arm_compute/core/NEON/NEMath.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Validate.h"

#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include <arm_neon.h>

namespace arm_compute
{
namespace
{
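/** Generic reduction driver.
 *
 * Walks the execution window slice by slice and applies the reduction functor F along the
 * requested axis (reduceX/reduceY/reduceZ/reduceW). The dispatch functions further down
 * instantiate it, for example:
 *
 *   Reducer<RedOpX<float, 4, ReductionOperation::SUM>>::reduceX(window, input, output,
 *                                                               RedOpX<float, 4, ReductionOperation::SUM>());
 */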
template <class F>
class Reducer
{
public:
    static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f)
    {
        // Set out window
        Window out_window(window);
        out_window.set(Window::DimX, Window::Dimension(0, 0, 0));

        // Get first input and output slices
        Window in_slice  = window.first_slice_window_1D();
        Window out_slice = out_window.first_slice_window_1D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info());
        }
        while(window.slide_window_slice_1D(in_slice) && out_window.slide_window_slice_1D(out_slice));
    }
    static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f)
    {
        // Set in/out window
        Window in_window(window);
        Window out_window(window);

        in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
        out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));

        // Get first input and output slices
        Window in_slice  = in_window.first_slice_window_2D();
        Window out_slice = out_window.first_slice_window_2D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), 1);
        }
        while(in_window.slide_window_slice_2D(in_slice) && out_window.slide_window_slice_2D(out_slice));
    }
    static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f)
    {
        // Set in/out window
        Window in_window(window);
        Window out_window(window);

        in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
        out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));

        // Get first input and output slices
        Window in_slice  = in_window.first_slice_window_3D();
        Window out_slice = out_window.first_slice_window_3D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), 2);
        }
        while(in_window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_3D(out_slice));
    }
    static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f)
    {
        // Set in/out window
        Window in_window(window);
        Window out_window(window);

        in_window.set(3, Window::Dimension(0, 1, 1));
        out_window.set(3, Window::Dimension(0, 1, 1));

        // Get first input and output slices
        Window in_slice  = in_window.first_slice_window_4D();
        Window out_slice = out_window.first_slice_window_4D();

        do
        {
            Iterator in(input, in_slice);
            Iterator out(output, out_slice);

            f(in, out, in_slice, out_slice, *input->info(), 3);
        }
        while(in_window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_4D(out_slice));
    }
};

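/** Reduction functor for the X axis on vectorizable types (T elements, S lanes per vector).
 *
 * Accumulates a whole X row into a single vector register, horizontally folds it with
 * pairwise additions and writes one scalar per output element. SUM_SQUARE squares the
 * elements before accumulating; MEAN_SUM divides the result by the row length. The X
 * dimension is expected to be padded to a multiple of the vector length (the kernel
 * declares a right border for axis-0 reductions in configure()).
 */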
template <typename T, int S, ReductionOperation op>
struct RedOpX
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;

    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info)
    {
        ARM_COMPUTE_UNUSED(out_slice);
        auto vec_sum_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            const auto in_ptr       = reinterpret_cast<const T *>(input.ptr());
            const auto vec_elements = wrapper::vloadq(in_ptr);

            if(op == ReductionOperation::SUM_SQUARE)
            {
                vec_sum_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_sum_value);
            }
            else
            {
                vec_sum_value = wrapper::vadd(vec_elements, vec_sum_value);
            }
        },
        input);

        auto carry_addition = wrapper::vpadd(wrapper::vgethigh(vec_sum_value), wrapper::vgetlow(vec_sum_value));
        for(int i = 0; i < S / 4; ++i)
        {
            carry_addition = wrapper::vpadd(carry_addition, carry_addition);
        }

        auto res = wrapper::vgetlane(carry_addition, 0);
        if(op == ReductionOperation::MEAN_SUM)
        {
            res /= in_info.dimension(0);
        }

        *(reinterpret_cast<T *>(output.ptr())) = res;
    }
};

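/** Reduction functor for the X axis on QASYMM8 (uint8_t) data.
 *
 * Widens each 16-byte load into four uint32x4_t accumulators (u8 -> u16 -> u32) so that long
 * rows can be summed without overflow, then folds the four accumulators into a single scalar.
 * MEAN_SUM divides by the row length before the result is narrowed back to uint8_t.
 */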
template <ReductionOperation op>
struct RedOpX_qasymm8
{
    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info)
    {
        ARM_COMPUTE_UNUSED(out_slice);
        auto vec_sum_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_sum_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_sum_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
        auto vec_sum_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            const auto vec_elements = wrapper::vloadq(input.ptr());

            const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
            const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));

            const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
            const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
            const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
            const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));

            vec_sum_value1 = wrapper::vadd(temp32x4t_1, vec_sum_value1);
            vec_sum_value2 = wrapper::vadd(temp32x4t_2, vec_sum_value2);
            vec_sum_value3 = wrapper::vadd(temp32x4t_3, vec_sum_value3);
            vec_sum_value4 = wrapper::vadd(temp32x4t_4, vec_sum_value4);
        },
        input);

        auto carry_addition = wrapper::vadd(vec_sum_value1, vec_sum_value2);
        carry_addition      = wrapper::vadd(carry_addition, vec_sum_value3);
        carry_addition      = wrapper::vadd(carry_addition, vec_sum_value4);

        auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_addition), wrapper::vgetlow(carry_addition));
        carry_paddition      = wrapper::vpadd(carry_paddition, carry_paddition);
        auto res             = wrapper::vgetlane(carry_paddition, 0);

        if(op == ReductionOperation::MEAN_SUM)
        {
            res /= in_info.dimension(0);
        }

        *(output.ptr()) = static_cast<uint8_t>(res);
    }
};

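/** Reduction functor for the Y, Z and W axes on vectorizable types (T elements, S lanes per vector).
 *
 * For each output position, accumulates a full vector of X elements across the reduced
 * dimension selected by 'axis', so the result stays vectorized along X. SUM_SQUARE squares
 * the elements before accumulating; MEAN_SUM multiplies by the reciprocal of the reduced
 * dimension's size.
 */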
template <typename T, int S, ReductionOperation op>
struct RedOpYZW
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;

    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis)
    {
        ARM_COMPUTE_UNUSED(out_slice);

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            auto vec_sum_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
            for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
            {
                T *in_ptr;
                switch(axis)
                {
                    case 1:
                        in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim)));
                        break;
                    case 2:
                        in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim)));
                        break;
                    case 3:
                        in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim)));
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
                const auto vec_elements = wrapper::vloadq(in_ptr);

                if(op == ReductionOperation::SUM_SQUARE)
                {
                    vec_sum_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_sum_value);
                }
                else
                {
                    vec_sum_value = wrapper::vadd(vec_elements, vec_sum_value);
                }
            }

            if(op == ReductionOperation::MEAN_SUM)
            {
                auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
                vec_sum_value      = wrapper::vmul(vec_sum_value, vec_width_inv);
            }

            wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_sum_value);
        },
        input, output);
    }
};

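/** Reduction functor for the Y, Z and W axes on QASYMM8 (uint8_t) data.
 *
 * Accumulates 16 X elements at a time into four widened uint32x4_t accumulators across the
 * reduced dimension, optionally averages them through float32 for MEAN_SUM, and narrows the
 * result back to 16 uint8_t values with saturation.
 */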
template <ReductionOperation op>
struct RedOpYZW_qasymm8
{
    inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis)
    {
        ARM_COMPUTE_UNUSED(out_slice);

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            auto vec_sum_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
            auto vec_sum_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
            auto vec_sum_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
            auto vec_sum_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));
            for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
            {
                uint8_t *in_ptr;
                switch(axis)
                {
                    case 1:
                        in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim));
                        break;
                    case 2:
                        in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim));
                        break;
                    case 3:
                        in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim));
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
                const auto vec_elements = wrapper::vloadq(in_ptr);

                const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
                const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));

                const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
                const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
                const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
                const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));

                vec_sum_value1 = wrapper::vadd(temp32x4t_1, vec_sum_value1);
                vec_sum_value2 = wrapper::vadd(temp32x4t_2, vec_sum_value2);
                vec_sum_value3 = wrapper::vadd(temp32x4t_3, vec_sum_value3);
                vec_sum_value4 = wrapper::vadd(temp32x4t_4, vec_sum_value4);
            }

            if(op == ReductionOperation::MEAN_SUM)
            {
                const auto vec_width_inv    = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
                const auto vec_sum_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_sum_value1), vec_width_inv);
                const auto vec_sum_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_sum_value2), vec_width_inv);
                const auto vec_sum_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_sum_value3), vec_width_inv);
                const auto vec_sum_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_sum_value4), vec_width_inv);

                vec_sum_value1 = vcvtq_u32_f32(vec_sum_value1_f);
                vec_sum_value2 = vcvtq_u32_f32(vec_sum_value2_f);
                vec_sum_value3 = vcvtq_u32_f32(vec_sum_value3_f);
                vec_sum_value4 = vcvtq_u32_f32(vec_sum_value4_f);
            }

            const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_sum_value1), wrapper::vqmovn(vec_sum_value2));
            const auto temp16x8t_2 = vcombine_u16(wrapper::vqmovn(vec_sum_value3), wrapper::vqmovn(vec_sum_value4));
            auto       res         = vcombine_u8(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
            wrapper::vstore(output.ptr(), res);
        },
        input, output);
    }
};

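/** Axis/data-type dispatchers.
 *
 * reduce_sumsq, reduce_sum and reduce_mean_sum select the Reducer/functor instantiation from
 * the reduction axis (0..3) and the input data type (QASYMM8, F16 when built with FP16 vector
 * arithmetic, F32). Unsupported combinations fall through to ARM_COMPUTE_ERROR.
 */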
void reduce_sumsq(const Window &window, const ITensor *input, ITensor *output, unsigned int axis)
{
    switch(axis)
    {
        case 0:
            switch(input->info()->data_type())
            {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpX<float16_t, 8, ReductionOperation::SUM_SQUARE>>::reduceX(window, input, output, RedOpX<float16_t, 8, ReductionOperation::SUM_SQUARE>());
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpX<float, 4, ReductionOperation::SUM_SQUARE>>::reduceX(window, input, output, RedOpX<float, 4, ReductionOperation::SUM_SQUARE>());
                case DataType::QASYMM8:
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 1:
            switch(input->info()->data_type())
            {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>>::reduceY(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>());
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>>::reduceY(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>());
                case DataType::QASYMM8:
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 2:
            switch(input->info()->data_type())
            {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>());
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>>::reduceZ(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>());
                case DataType::QASYMM8:
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 3:
            switch(input->info()->data_type())
            {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>>::reduceW(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>());
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>>::reduceW(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>());
                case DataType::QASYMM8:
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        default:
            ARM_COMPUTE_ERROR("Unsupported reduction axis");
    }
}

void reduce_sum(const Window &window, const ITensor *input, ITensor *output, unsigned int axis)
{
    switch(axis)
    {
        case 0:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpX_qasymm8<ReductionOperation::SUM>>::reduceX(window, input, output, RedOpX_qasymm8<ReductionOperation::SUM>());
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpX<float16_t, 8, ReductionOperation::SUM>>::reduceX(window, input, output, RedOpX<float16_t, 8, ReductionOperation::SUM>());
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpX<float, 4, ReductionOperation::SUM>>::reduceX(window, input, output, RedOpX<float, 4, ReductionOperation::SUM>());
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 1:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8<ReductionOperation::SUM>>::reduceY(window, input, output, RedOpYZW_qasymm8<ReductionOperation::SUM>());
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM>>::reduceY(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM>());
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM>>::reduceY(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM>());
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 2:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8<ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_qasymm8<ReductionOperation::SUM>());
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM>());
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM>());
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 3:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8<ReductionOperation::SUM>>::reduceW(window, input, output, RedOpYZW_qasymm8<ReductionOperation::SUM>());
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM>>::reduceW(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM>());
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM>>::reduceW(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM>());
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        default:
            ARM_COMPUTE_ERROR("Unsupported reduction axis");
    }
}

void reduce_mean_sum(const Window &window, const ITensor *input, ITensor *output, unsigned int axis)
{
    switch(axis)
    {
        case 0:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpX_qasymm8<ReductionOperation::MEAN_SUM>>::reduceX(window, input, output, RedOpX_qasymm8<ReductionOperation::MEAN_SUM>());
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpX<float16_t, 8, ReductionOperation::MEAN_SUM>>::reduceX(window, input, output, RedOpX<float16_t, 8, ReductionOperation::MEAN_SUM>());
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpX<float, 4, ReductionOperation::MEAN_SUM>>::reduceX(window, input, output, RedOpX<float, 4, ReductionOperation::MEAN_SUM>());
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 1:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>>::reduceY(window, input, output, RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>());
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>>::reduceY(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>());
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>>::reduceY(window, input, output, RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>());
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 2:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>>::reduceZ(window, input, output, RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>());
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>());
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>>::reduceZ(window, input, output, RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>());
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 3:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>>::reduceW(window, input, output, RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>());
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>>::reduceW(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>());
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>>::reduceW(window, input, output, RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>());
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        default:
            ARM_COMPUTE_ERROR("Unsupported reduction axis");
    }
}

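// The reduced dimension collapses to size 1 in the output shape, e.g. reducing axis 1 of a
// (W, H, C) input yields a (W, 1, C) output.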
TensorShape calculate_output_shape(const TensorShape &input_shape, unsigned int axis)
{
    TensorShape output_shape{ input_shape };
    output_shape.set(axis, 1);

    return output_shape;
}

Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
    ARM_COMPUTE_UNUSED(op);

    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");

    if(output->total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);

        const TensorShape output_shape         = calculate_output_shape(input->tensor_shape(), axis);
        const TensorInfo  tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
    }

    return Status{};
}

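// Each iteration processes one full 16-byte vector, i.e. 16 QASYMM8 values, 8 F16 values or
// 4 F32 values per step along X.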
std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis)
{
    // Calculate output shape and set if empty
    const TensorShape output_shape = calculate_output_shape(input->tensor_shape(), axis);

    // Output auto initialization if not yet initialized
    auto_init_if_empty(*output, output_shape, 1, input->data_type());

    unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());

    // Configure kernel window
    Window                 win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
    AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);

    bool window_changed = update_window_and_padding(win, input_access, output_access);
    output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};

    return std::make_tuple(err, win);
}
} // namespace

NEReductionOperationKernel::NEReductionOperationKernel()
    : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE), _border_size()
{
}

BorderSize NEReductionOperationKernel::border_size() const
{
    return _border_size;
}

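/** Usage sketch (illustrative only; 'src' and 'dst' below are assumed to be allocated ITensor
 *  objects, not part of this file). In practice the runtime function NEReductionOperation
 *  configures and schedules this kernel, but a direct invocation would look roughly like:
 *
 *    NEReductionOperationKernel kernel;
 *    kernel.configure(&src, &dst, 0, ReductionOperation::SUM); // reduce along axis 0 (X)
 *    NEScheduler::get().schedule(&kernel, Window::DimY);
 */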
void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));

    unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());

    _input          = input;
    _output         = output;
    _border_size    = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
    _op             = op;
    _reduction_axis = axis;

    // Configure kernel window
    auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis);

    ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));

    INEKernel::configure(std::get<1>(win_config));
}

Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
    ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis)));

    return Status{};
}

void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    switch(_op)
    {
        case ReductionOperation::SUM_SQUARE:
            reduce_sumsq(window, _input, _output, _reduction_axis);
            break;
        case ReductionOperation::MEAN_SUM:
            reduce_mean_sum(window, _input, _output, _reduction_axis);
            break;
        case ReductionOperation::SUM:
            reduce_sum(window, _input, _output, _reduction_axis);
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported reduction operation.");
    }
}
} // namespace arm_compute