/*
 * Copyright (c) 2021-2024 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ACL_SRC_CPU_KERNELS_SOFTMAX_GENERIC_NEON_IMPL_H
#define ACL_SRC_CPU_KERNELS_SOFTMAX_GENERIC_NEON_IMPL_H

#include "arm_compute/core/Helpers.h"

#include "src/core/NEON/NEMath.h"
#include "src/core/NEON/wrapper/wrapper.h"

namespace arm_compute
{
namespace cpu
{

#ifdef __aarch64__
namespace
{
// These helper functions are added because vaddv does not exist for fp16,
// and, therefore, is not part of the wrapper::vaddv interface.
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
inline float16_t wrapper_vaddv(const float16x8_t &a, int sum_stages)
{
    auto sum_res = wrapper::vpadd(wrapper::vgethigh(a), wrapper::vgetlow(a));
    for (int i = 0; i < sum_stages; ++i)
    {
        sum_res = wrapper::vpadd(sum_res, sum_res);
    }
    return wrapper::vgetlane(sum_res, 0);
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

inline float wrapper_vaddv(const float32x4_t &a, int sum_stages)
{
    ARM_COMPUTE_UNUSED(sum_stages);
    return wrapper::vaddv(a);
}
} // namespace
#endif // __aarch64__

// The template implementation for float data types is stored in the header file because
// we need all fp16 instantiated code to live in fp16.cpp files.
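//
// neon_softmax_x_float computes (log-)softmax along the innermost (x) axis of every row selected
// by the window, in three passes over the row:
//   1) m   = max_i(x_i)                             (max subtraction for numerical stability)
//   2) e_i = exp(beta * (x_i - m)),  s = sum_i e_i  (intermediates are staged in the output buffer)
//   3) softmax:     out_i = e_i / s
//      log-softmax: out_i = beta * (x_i - m) - log(s)
// Elements that do not fill a whole SIMD vector are handled by the scalar leftover loops.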
template <typename T, bool IS_LOG>
void neon_softmax_x_float(const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
{
    ARM_COMPUTE_UNUSED(axis);
    ARM_COMPUTE_UNUSED(tmp);

    const int input_width = in->info()->valid_region().shape.x();

    Iterator in_it(in, window);
    Iterator out_it(out, window);

    /** SIMD vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    constexpr int vec_size = 16 / sizeof(T);

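    // Number of extra pairwise (vpadd/vpmax) stages needed to finish reducing a full vector after
    // the initial high/low fold: log2(vec_size / 2), i.e. 1 for fp32 (4 lanes), 2 for fp16 (8 lanes).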
    const int sum_stages = log2(vec_size >> 1);

    const auto beta_vec = wrapper::vdup_n(static_cast<T>(beta), ExactTagType{});

    execute_window_loop(
        window,
        [&](const Coordinates &)
        {
            /* Get pointers */
            const T *in_ptr = reinterpret_cast<const T *>(in_it.ptr());
            T *out_ptr = reinterpret_cast<T *>(out_it.ptr());

            T max_val;

            /* Compute Max */
            {
                // Init max value
                auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
                int x = 0;

                for (; x <= (input_width - vec_size); x += vec_size)
                {
                    const auto current_value = wrapper::vloadq(in_ptr + x);
                    vec_max = wrapper::vmax(vec_max, current_value);
                }

#ifdef __aarch64__
                max_val = wrapper::vmaxv(vec_max);
#else // __aarch64__
                auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));

                for (int i = 0; i < sum_stages; ++i)
                {
                    carry_max = wrapper::vpmax(carry_max, carry_max);
                }

                max_val = wrapper::vgetlane(carry_max, 0);
#endif // __aarch64__

                // Compute left-over elements
                for (; x < input_width; ++x)
                {
                    max_val = std::max(*(in_ptr + x), max_val);
                }
            } // compute max

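            // For softmax, sum_transformed holds 1 / sum so the final pass is a single multiply;
            // for log-softmax it holds log(sum) so the final pass is a single subtraction.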
            T sum_transformed{};

            /* Compute exponentials and sum */
            {
                /* Get max value */
                const auto vec_max = wrapper::vdup_n(max_val, ExactTagType{});

                /* Init sum to zero */
                auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});

                /* Loop over row and compute exponentials and sum */
                int x = 0;
                for (; x <= (input_width - vec_size); x += vec_size)
                {
                    auto vec_elements = wrapper::vloadq(in_ptr + x);
                    vec_elements = wrapper::vsub(vec_elements, vec_max);
                    if (IS_LOG)
                    {
                        vec_elements = wrapper::vmul(vec_elements, beta_vec);
                        vec_sum = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
                    }
                    else
                    {
                        vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, beta_vec));
                        vec_sum = wrapper::vadd(vec_sum, vec_elements);
                    }
                    wrapper::vstore(out_ptr + x, vec_elements);
                }

                /* Reduce sum */
                T sum{};
#ifdef __aarch64__
                sum = wrapper_vaddv(vec_sum, sum_stages);
#else // __aarch64__
                auto sum_res = wrapper::vpadd(wrapper::vgethigh(vec_sum), wrapper::vgetlow(vec_sum));
                for (int i = 0; i < sum_stages; ++i)
                {
                    sum_res = wrapper::vpadd(sum_res, sum_res);
                }
                sum = wrapper::vgetlane(sum_res, 0);
#endif // __aarch64__

                /* Run remaining elements */
                for (; x < input_width; ++x)
                {
                    T element{};

                    if (IS_LOG)
                    {
                        element = (in_ptr[x] - max_val) * beta;
                        sum += std::exp(element);
                    }
                    else
                    {
                        element = std::exp((in_ptr[x] - max_val) * beta);
                        sum += element;
                    }

                    out_ptr[x] = element;
                }

                if (!IS_LOG)
                {
                    sum_transformed = T(1) / sum;
                }
                else
                {
                    sum_transformed = static_cast<T>(std::log(sum));
                }
            } // Compute exponentials and sum

            /* Normalize exponentials */
            {
                const auto sum_vec = wrapper::vdup_n(static_cast<T>(sum_transformed), ExactTagType{});

                /* Loop over row and compute softmax */
                int x = 0;
                for (; x <= (input_width - vec_size); x += vec_size)
                {
                    const auto vec_in = wrapper::vloadq(out_ptr + x);
                    if (IS_LOG)
                    {
                        wrapper::vstore(out_ptr + x, wrapper::vsub(vec_in, sum_vec));
                    }
                    else
                    {
                        wrapper::vstore(out_ptr + x, wrapper::vmul(vec_in, sum_vec));
                    }
                }

                /* Run remaining elements */
                for (; x < input_width; ++x)
                {
                    if (IS_LOG)
                    {
                        out_ptr[x] = out_ptr[x] - sum_transformed;
                    }
                    else
                    {
                        out_ptr[x] = out_ptr[x] * sum_transformed;
                    }
                }
            } // Normalize exponentials
        },
        in_it, out_it);
}
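
// neon_softmax_non_x_float computes (log-)softmax along an axis other than the innermost (x) one.
// Instead of reducing within a vector, each SIMD lane processes an independent softmax row: the
// kernel loads vec_size adjacent x elements at a time and steps along the softmax axis using the
// tensors' byte strides, so the max / exp-sum / normalize passes run on vec_size rows in parallel.
// When fewer than vec_size x elements remain, the lanes are handled with scalar code.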
template <typename T, bool IS_LOG>
void neon_softmax_non_x_float(
    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
{
    ARM_COMPUTE_UNUSED(tmp);

    Iterator in_it(in, window);
    Iterator out_it(out, window);

    /** SIMD vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    const auto beta_vec = wrapper::vdup_n(static_cast<T>(beta), ExactTagType{});
    constexpr int vec_size = 16 / sizeof(T);
    const ITensorInfo *in_info = in->info();
    const ITensorInfo *out_info = out->info();
    const int x_width = in_info->valid_region().shape.x();
    const unsigned int in_axis_stride = in_info->strides_in_bytes()[axis];
    const unsigned int out_axis_stride = out_info->strides_in_bytes()[axis];
    const int axis_width = in_info->dimension(axis);

    execute_window_loop(
        window,
        [&](const Coordinates &winCoords)
        {
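            // Each SIMD lane handles an independent softmax row: lane j corresponds to x position
            // winCoords[0] + j. If the window step would run past the valid x extent, only the
            // remaining lanes are processed, using scalar code.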
            const bool vector_exceeds_bounds = (winCoords[0] + vec_size) > x_width;

            /* Get pointers */
            const uint8_t *in_ptr = in_it.ptr();
            uint8_t *out_ptr = out_it.ptr();

            // Init max value
            auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});

            /* Compute Max */
            {
                if (!vector_exceeds_bounds)
                {
                    int i = 0;
                    for (; i < axis_width; ++i)
                    {
                        const auto current_value =
                            wrapper::vloadq(reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr));
                        vec_max = wrapper::vmax(vec_max, current_value);
                    }
                }
                else
                {
                    int i = 0;
                    for (; i < axis_width; ++i)
                    {
                        const T *const base_ptr_in = reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr);
                        int j = 0;
                        for (; j < (x_width - winCoords[0]); ++j)
                        {
                            const auto current_value = *(base_ptr_in + j);
                            vec_max[j] = std::max(vec_max[j], current_value);
                        }
                    }
                }
            } // compute max

            auto vec_sum_transformed = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});

            auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
            /* Init sum to zero */
            auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});

            /* Compute exponentials and sum */
            {
                if (!vector_exceeds_bounds)
                {
                    const auto vec_one = wrapper::vdup_n(static_cast<T>(1), ExactTagType{});
                    /* Loop over row and compute exponentials and sum */
                    int i = 0;
                    for (; i < axis_width; ++i)
                    {
                        vec_elements = wrapper::vloadq(reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr));
                        vec_elements = wrapper::vsub(vec_elements, vec_max);
                        if (IS_LOG)
                        {
                            vec_elements = wrapper::vmul(vec_elements, beta_vec);
                            vec_sum = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
                        }
                        else
                        {
                            vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, beta_vec));
                            vec_sum = wrapper::vadd(vec_sum, vec_elements);
                        }

                        wrapper::vstore(reinterpret_cast<T *>((i * out_axis_stride) + out_ptr), vec_elements);
                    }

                    if (!IS_LOG)
                    {
                        vec_sum_transformed = wrapper::vdiv(vec_one, vec_sum);
                    }
                    else
                    {
                        vec_sum_transformed = wrapper::vlog(vec_sum);
                    }
                }
                else
                {
                    int i = 0;
                    for (; i < axis_width; ++i)
                    {
                        const T *const base_ptr_in = reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr);
                        T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
                        int j = 0;
                        for (; j < (x_width - winCoords[0]); ++j)
                        {
                            vec_elements[j] = *(base_ptr_in + j);
                            vec_elements[j] -= vec_max[j];
                            if (IS_LOG)
                            {
                                vec_elements[j] *= beta;
                                vec_sum[j] += std::exp(vec_elements[j]);
                            }
                            else
                            {
                                vec_elements[j] = std::exp(vec_elements[j] * beta);
                                vec_sum[j] += vec_elements[j];
                            }
                            *(base_ptr_out + j) = vec_elements[j];
                        }
                    }
                    int j = 0;
                    for (; j < (x_width - winCoords[0]); ++j)
                    {
                        if (!IS_LOG)
                        {
                            vec_sum_transformed[j] = 1 / vec_sum[j];
                        }
                        else
                        {
                            vec_sum_transformed[j] = std::log(vec_sum[j]);
                        }
                    }
                }
            } // Compute exponentials and sum

            /* Normalize exponentials */
            {
                if (!vector_exceeds_bounds)
                {
                    /* Loop over row and compute softmax */
                    int i = 0;
                    for (; i < axis_width; ++i)
                    {
                        T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
                        auto vec_in = wrapper::vloadq(base_ptr_out);
                        if (IS_LOG)
                        {
                            wrapper::vstore(base_ptr_out, wrapper::vsub(vec_in, vec_sum_transformed));
                        }
                        else
                        {
                            wrapper::vstore(base_ptr_out, wrapper::vmul(vec_in, vec_sum_transformed));
                        }
                    }
                }
                else
                {
                    int i = 0;
                    for (; i < axis_width; ++i)
                    {
                        T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
                        int j = 0;
                        for (; j < (x_width - winCoords[0]); ++j)
                        {
                            if (IS_LOG)
                            {
                                *(base_ptr_out + j) -= vec_sum_transformed[j];
                            }
                            else
                            {
                                *(base_ptr_out + j) *= vec_sum_transformed[j];
                            }
                        }
                    }
                }
            } // Normalize exponentials
        },
        in_it, out_it);
}
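
// The quantized variants are only declared here; their definitions are provided outside this
// header (see the note above about keeping instantiated code in per-data-type source files).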
template <typename T, bool IS_LOG>
void neon_softmax_x_quantized(
    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);

template <typename T, bool IS_LOG>
void neon_softmax_non_x_quantized(
    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
} // namespace cpu
} // namespace arm_compute

#endif // ACL_SRC_CPU_KERNELS_SOFTMAX_GENERIC_NEON_IMPL_H