/*
 * Copyright (c) 2024 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_FP16_H
#define ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_FP16_H

#include "arm_compute/core/Coordinates.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"

#include "src/core/NEON/NEMath.h"
#include "src/core/NEON/wrapper/wrapper.h"
#include "support/SaturateCast.h"

#include <arm_neon.h>

namespace arm_compute
{
// Helper function that calls vqmovn, vcombine and vstore; allows templating of RedOpYZW_quantized
inline void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0)
{
    auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2));
    wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), res);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
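// Computes updated per-lane argmin/argmax indices for eight fp16 lanes: lanes where the
// candidate value `a` strictly beats the current best `b` take the new index `idx` (or a
// broadcast of it for axis != 0), all other lanes keep their previous index from `c`.
// Only the first two uint32x4_t entries of the result carry data.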
inline uint32x4x4_t
calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4x2_t mask{0};
    uint16x8_t mask_u16{0};
    if (op == ReductionOperation::ARG_IDX_MIN)
    {
        mask_u16 = wrapper::vcgt(b, a);
    }
    else
    {
        mask_u16 = wrapper::vclt(b, a);
    }
    mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
    mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
    uint32x4x2_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, {idx + 4, idx + 5, idx + 6, idx + 7}}};
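    // When reducing along an axis other than X, all lanes share the same index,
    // so broadcast `idx` instead of using consecutive lane indices.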
    if (axis != 0)
    {
        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    uint32x4x4_t res = {wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
                        wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]), 0, 0};

    return res;
}

// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
inline float16x4_t calculate_min(float16x8_t in)
{
    auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
    pmin = wrapper::vpmin(pmin, pmin);
    return wrapper::vpmin(pmin, pmin);
}
// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
inline float16x4_t calculate_max(float16x8_t in)
{
    auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
    pmax = wrapper::vpmax(pmax, pmax);
    return wrapper::vpmax(pmax, pmax);
}

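// Reduces a set of per-lane indices to a single scalar: keeps only the indices of lanes
// whose value equals the vector-wide min (ARG_IDX_MIN) or max (ARG_IDX_MAX), then takes
// the smallest surviving index via pairwise-min reductions.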
inline uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
{
    uint32x4x2_t res_idx_mask{0};
    uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
    uint16x8_t mask_u16;
    if (op == ReductionOperation::ARG_IDX_MIN)
    {
        auto pmin = calculate_min(vec_res_value);
        mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
    }
    else
    {
        auto pmax = calculate_max(vec_res_value);
        mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
    }

    // Widen vectors
    auto wide_u32_1 =
        wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
    auto wide_u32_2 =
        wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
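    // Adding all-ones (i.e. subtracting 1) maps non-matching lanes (0) to 0xFFFFFFFF and
    // matching lanes to index - 1; the final (res - 0xFFFFFFFF) below adds that 1 back.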
    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);

    uint32_t res = 0xFFFFFFFF;
    uint32_t iter = 0;
    do
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
        pmin = wrapper::vpmin(pmin, pmin);
        res = std::min(wrapper::vgetlane(pmin, 0), res);
        iter++;
    } while (iter < 2);

    return (res - 0xFFFFFFFF);
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

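// Sets up the input/output windows for the requested reduction axis and forwards
// to the reduction functor F.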
template <class F>
class Reducer
{
public:
    static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set out window
        Window out_window(window);
        out_window.set(Window::DimX, Window::Dimension(0, 1, 1));

        f(window, out_window, input, output, op);
    }
    static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in window
        Window in_window(window);
        Window out_window(window);

        in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
        out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));

        f(in_window, out_window, input, output, 1, op);
    }
    static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in window
        Window in_window(window);
        Window out_window(window);

        in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
        out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));

        f(in_window, out_window, input, output, 2, op);
    }
    static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
    {
        // Set in/out window
        Window in_window(window);
        Window out_window(window);

        in_window.set(3, Window::Dimension(0, 1, 1));
        out_window.set(3, Window::Dimension(0, 1, 1));

        f(in_window, out_window, input, output, 3, op);
    }
};

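// Reduction along the X (innermost) axis: accumulates window_step_x elements per vector
// iteration and finishes the remainder with a scalar tail loop.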
template <typename T, int S>
struct RedOpX
{
    /** SIMD vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;

    inline void operator()(
        const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
    {
        const size_t input_dim_0 = in->info()->dimension(0);
        const int window_step_x = 16 / sizeof(T);
        const auto window_start_x = static_cast<int>(in_window.x().start());
        const auto window_end_x = static_cast<int>(in_window.x().end());

        Window in_win_no_pad = in_window;
        in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input(in, in_win_no_pad);
        Iterator output(out, out_window);

        execute_window_loop(
            in_win_no_pad,
            [&](const Coordinates &)
            {
                const auto input_ptr = reinterpret_cast<const T *>(input.ptr());

                auto init_res_value = static_cast<T>(0.f);
                switch (op)
                {
                    case ReductionOperation::ARG_IDX_MAX:
                    case ReductionOperation::ARG_IDX_MIN:
                    case ReductionOperation::MIN:
                    case ReductionOperation::MAX:
                    {
                        init_res_value = static_cast<T>(*input_ptr);
                        break;
                    }
                    case ReductionOperation::PROD:
                    {
                        init_res_value = static_cast<T>(1.f);
                        break;
                    }
                    default:
                        break;
                }
                auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
                uint32x4x4_t vec_res_idx{{0}};

                // Compute window_step_x elements per iteration
                int x = window_start_x;
                for (; x <= (window_end_x - window_step_x); x += window_step_x)
                {
                    const auto vec_elements = wrapper::vloadq(input_ptr + x);
                    switch (op)
                    {
                        case ReductionOperation::SUM_SQUARE:
                            vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                            break;
                        case ReductionOperation::MEAN_SUM:
                        case ReductionOperation::SUM:
                            vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                            break;
                        case ReductionOperation::PROD:
                            vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
                            break;
                        case ReductionOperation::ARG_IDX_MIN:
                        {
                            auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                            vec_res_idx = calculate_index(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                            vec_res_value = temp_vec_res_value;
                            break;
                        }
                        case ReductionOperation::ARG_IDX_MAX:
                        {
                            auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                            vec_res_idx = calculate_index(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                            vec_res_value = temp_vec_res_value;
                            break;
                        }
                        case ReductionOperation::MIN:
                        {
                            vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                            break;
                        }
                        case ReductionOperation::MAX:
                        {
                            vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                            break;
                        }
                        default:
                            ARM_COMPUTE_ERROR("Not supported");
                    }
                }

                switch (op)
                {
                    case ReductionOperation::SUM:
                    case ReductionOperation::MEAN_SUM:
                    case ReductionOperation::SUM_SQUARE:
                    {
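                        // Debug builds accumulate lane by lane in scalar code; release
                        // builds reduce the vector with pairwise adds.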
#ifdef ARM_COMPUTE_DEBUG_ENABLED
                        auto res = static_cast<T>(0.f);
                        for (int i = 0; i < S; ++i)
                        {
                            res += wrapper::vgetlane(vec_res_value, i);
                        }
#else // ARM_COMPUTE_DEBUG_ENABLED
                        auto carry_res =
                            wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                        for (int i = 0; i < S / 4; ++i)
                        {
                            carry_res = wrapper::vpadd(carry_res, carry_res);
                        }
                        auto res = wrapper::vgetlane(carry_res, 0);
#endif // ARM_COMPUTE_DEBUG_ENABLED
                        if (op == ReductionOperation::SUM_SQUARE)
                        {
                            // Compute left-over elements
                            for (; x < window_end_x; ++x)
                            {
                                res += (*(input_ptr + x)) * (*(input_ptr + x));
                            }
                        }
                        else
                        {
                            // Compute left-over elements
                            for (; x < window_end_x; ++x)
                            {
                                res += *(input_ptr + x);
                            }
                        }

                        if (op == ReductionOperation::MEAN_SUM)
                        {
                            res /= input_dim_0;
                        }

                        *(reinterpret_cast<T *>(output.ptr())) = res;
                        break;
                    }
                    case ReductionOperation::PROD:
                    {
                        auto carry_res =
                            wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                        T res = 1;
                        for (int i = 0; i < S / 2; ++i)
                        {
                            res *= wrapper::vgetlane(carry_res, i);
                        }

                        // Compute left-over elements
                        for (; x < window_end_x; ++x)
                        {
                            res *= *(input_ptr + x);
                        }

                        *(reinterpret_cast<T *>(output.ptr())) = res;
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MIN:
                    {
                        auto idx = calculate_vector_index(vec_res_idx, vec_res_value, op);
                        auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));

                        // Compute left-over elements
                        for (; x < window_end_x; ++x)
                        {
                            if (*(input_ptr + x) < res)
                            {
                                idx = x;
                                res = *(input_ptr + x);
                            }
                        }
                        *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MAX:
                    {
                        auto idx = calculate_vector_index(vec_res_idx, vec_res_value, op);
                        auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));

                        // Compute left-over elements
                        for (; x < window_end_x; ++x)
                        {
                            if (*(input_ptr + x) > res)
                            {
                                idx = x;
                                res = *(input_ptr + x);
                            }
                        }
                        *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
                        break;
                    }
                    case ReductionOperation::MIN:
                    {
                        auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));

                        // Compute left-over elements
                        for (; x < window_end_x; ++x)
                        {
                            res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
                        }
                        *(reinterpret_cast<T *>(output.ptr())) = res;
                        break;
                    }
                    case ReductionOperation::MAX:
                    {
                        auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));

                        // Compute left-over elements
                        for (; x < window_end_x; ++x)
                        {
                            res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
                        }
                        *(reinterpret_cast<T *>(output.ptr())) = res;
                        break;
                    }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            },
            input, output);
    }
};

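// Reduction along the Y, Z or W axis: for each x position, walks the reduced dimension
// via the input tensor's byte strides, vectorized over x with a scalar tail loop.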
template <typename T, int S>
struct RedOpYZW
{
    /** SIMD vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
    using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;

    inline void operator()(const Window &in_window,
                           Window &out_window,
                           const ITensor *in,
                           ITensor *out,
                           int axis,
                           const ReductionOperation op)
    {
        const TensorInfo in_info = *(in->info());
        const int window_step_x = 16 / sizeof(T);
        const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
        const auto window_end_x_tmp = static_cast<int>(in_window.x().end());
        // Since the window is split over the x-axis, set the correct split window start and end.
        const auto window_start_x = static_cast<int>(0);
        const auto window_end_x = static_cast<int>(in_window.shape().x());

        Window in_win_no_pad = in_window;
        in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
        Window out_win_no_pad = out_window;
        out_win_no_pad.set(Window::DimX,
                           Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));

        Iterator input(in, in_win_no_pad);
        Iterator output(out, out_win_no_pad);

        execute_window_loop(
            in_win_no_pad,
            [&](const Coordinates &)
            {
                const auto input_ptr = reinterpret_cast<T *>(input.ptr());

                // Compute window_step_x elements per iteration
                int x = window_start_x;
                for (; x <= (window_end_x - window_step_x); x += window_step_x)
                {
                    neon_vector vec_res_value = {0};
                    switch (op)
                    {
                        case ReductionOperation::ARG_IDX_MAX:
                        case ReductionOperation::ARG_IDX_MIN:
                        case ReductionOperation::MIN:
                        case ReductionOperation::MAX:
                        {
                            vec_res_value = wrapper::vloadq(input_ptr + x);
                            break;
                        }
                        case ReductionOperation::PROD:
                        {
                            vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
                            break;
                        }
                        default:
                        {
                            vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
                            break;
                        }
                    }
                    uint32x4x4_t vec_res_idx{{0}};

                    for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                    {
                        const T *in_ptr =
                            reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);
                        const auto vec_elements = wrapper::vloadq(in_ptr);
                        switch (op)
                        {
                            case ReductionOperation::SUM:
                            case ReductionOperation::MEAN_SUM:
                                vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                                break;
                            case ReductionOperation::SUM_SQUARE:
                                vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                                break;
                            case ReductionOperation::PROD:
                                vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
                                break;
                            case ReductionOperation::ARG_IDX_MIN:
                            {
                                auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                                vec_res_idx =
                                    calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                                vec_res_value = temp_vec_res_value;
                                break;
                            }
                            case ReductionOperation::ARG_IDX_MAX:
                            {
                                auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                                vec_res_idx =
                                    calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                                vec_res_value = temp_vec_res_value;
                                break;
                            }
                            case ReductionOperation::MIN:
                            {
                                vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                                break;
                            }
                            case ReductionOperation::MAX:
                            {
                                vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                                break;
                            }
                            default:
                                ARM_COMPUTE_ERROR("Not supported");
                        }
                    }

                    if (op == ReductionOperation::MEAN_SUM)
                    {
                        auto vec_width_inv =
                            wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
                        vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
                    }

                    if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
                    {
                        wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
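                        // fp16 vectors hold eight lanes, so the second set of four
                        // indices must be stored as well.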
                        if (std::is_same<T, float16_t>::value)
                        {
                            wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, vec_res_idx.val[1]);
                        }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    }
                    else
                    {
                        wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value);
                    }
                }

                // Compute left-over elements
                for (; x < window_end_x; ++x)
                {
                    auto res_value = 0.f;
                    switch (op)
                    {
                        case ReductionOperation::ARG_IDX_MAX:
                        case ReductionOperation::ARG_IDX_MIN:
                        case ReductionOperation::MIN:
                        case ReductionOperation::MAX:
                        {
                            res_value = *(input_ptr + x);
                            break;
                        }
                        case ReductionOperation::PROD:
                        {
                            res_value = static_cast<T>(1.f);
                            break;
                        }
                        default:
                        {
                            res_value = static_cast<T>(0.f);
                            break;
                        }
                    }

                    uint32_t res_idx = 0;
                    for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                    {
                        const T *in_ptr =
                            reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);

                        switch (op)
                        {
                            case ReductionOperation::SUM:
                            case ReductionOperation::MEAN_SUM:
                                res_value += *in_ptr;
                                break;
                            case ReductionOperation::SUM_SQUARE:
                                res_value += *in_ptr * *in_ptr;
                                break;
                            case ReductionOperation::PROD:
                                res_value *= *in_ptr;
                                break;
                            case ReductionOperation::ARG_IDX_MIN:
                            {
                                if (*in_ptr < res_value)
                                {
                                    res_value = *in_ptr;
                                    res_idx = dim;
                                }
                                break;
                            }
                            case ReductionOperation::ARG_IDX_MAX:
                            {
                                if (*in_ptr > res_value)
                                {
                                    res_value = *in_ptr;
                                    res_idx = dim;
                                }
                                break;
                            }
                            case ReductionOperation::MIN:
                            {
                                res_value = *in_ptr < res_value ? *in_ptr : res_value;
                                break;
                            }
                            case ReductionOperation::MAX:
                            {
                                res_value = *in_ptr > res_value ? *in_ptr : res_value;
                                break;
                            }
                            default:
                                ARM_COMPUTE_ERROR("Not supported");
                        }
                    }

                    if (op == ReductionOperation::MEAN_SUM)
                    {
                        res_value /= in_info.dimension(axis);
                    }

                    if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
                    {
                        *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx;
                    }
                    else
                    {
                        *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value;
                    }
                }
            },
            input, output);
    }
};

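// Specialised sum reduction over the Z axis for complex (interleaved real/imaginary)
// data; only axis == 2 with ReductionOperation::SUM is supported.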
template <typename T, int S, int axis, ReductionOperation op>
struct RedOpYZW_complex
{
    /** SIMD vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
    using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;

    inline void operator()(
        const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation)
    {
        ARM_COMPUTE_ERROR_ON(axis != 2);
        ARM_COMPUTE_ERROR_ON(op != ReductionOperation::SUM);

        const TensorInfo in_info = *(in->info());
        const size_t stride_z = in_info.strides_in_bytes()[axis];
        const int window_step_x = 16 / sizeof(T);
        const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
        const auto window_end_x_tmp = static_cast<int>(in_window.x().end());
        // Since the window is split over the x-axis, set the correct split window start and end.
        const auto window_start_x = static_cast<int>(0);
        const auto window_end_x = static_cast<int>(in_window.shape().x());

        Window in_win_no_pad = in_window;
        in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
        Window out_win_no_pad = out_window;
        out_win_no_pad.set(Window::DimX,
                           Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));

        Iterator input(in, in_win_no_pad);
        Iterator output(out, out_win_no_pad);

        execute_window_loop(
            in_win_no_pad,
            [&](const Coordinates &)
            {
                // Compute window_step_x elements per iteration
                int x = window_start_x;
                for (; x <= (window_end_x - window_step_x); x += window_step_x)
                {
                    neon_vector vec_res_value_0 = {0};
                    neon_vector vec_res_value_1 = {0};

                    vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
                    vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});

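                    // Complex data is stored as interleaved (real, imaginary) pairs, so each
                    // x position spans two scalars; the two accumulators together cover one
                    // full block of interleaved values per iteration.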
                    T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
                    for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                    {
                        T *in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
                        T *in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim);

                        const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
                        const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);

                        vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
                        vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
                    }

                    wrapper::vstore(out_ptr, vec_res_value_0);
                    wrapper::vstore(out_ptr + 4, vec_res_value_1);
                }

                // Compute left-over elements
                for (; x < window_end_x; ++x)
                {
                    auto res_value_0 = 0.f;
                    auto res_value_1 = 0.f;

                    T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
                    for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                    {
                        T *in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
                        res_value_0 += *in_ptr;
                        res_value_1 += *(in_ptr + 1);
                    }
                    *out_ptr = res_value_0;
                    *(out_ptr + 1) = res_value_1;
                }
            },
            input, output);
    }
};
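// A minimal usage sketch (assuming an FP16 build and tensors/windows configured by the
// caller; the wiring below is illustrative and not part of this header):
//
//   RedOpX<float16_t, 8> red_op{};
//   Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, red_op,
//                                          ReductionOperation::SUM);
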
} // namespace arm_compute
#endif // ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_FP16_H