/*
 * Copyright (c) 2021-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/cpu/kernels/softmax/generic/neon/impl.h"

#include "support/SaturateCast.h"

namespace arm_compute
{
namespace cpu
{
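// Quantized (log-)softmax along the innermost dimension, computed in three passes per row:
//   1) find the row maximum,
//   2) compute the exponentials of beta * scale * (x - max) into the float scratch buffer 'tmp'
//      while accumulating their sum (for IS_LOG the scratch holds the scaled differences instead),
//   3) normalize the scratch values and requantize them back to the 8-bit output type.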
template <typename T, bool IS_LOG>
void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window)
{
    static_assert(std::is_same<T, qasymm8_t>::value || std::is_same<T, qasymm8_signed_t>::value,
                  "quantized type should be either qasymm8_t or qasymm8_signed_t.");

    const int input_width = in->info()->valid_region().shape.x();

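    // Fold the dequantization scale and the softmax beta into one factor. The sign is negative
    // because the kernel works with (max - x) >= 0, so multiplying by -beta * scale yields
    // beta * scale * (x - max), which is the exponent the softmax needs.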
    const float       scale_beta     = -beta * in->info()->quantization_info().uniform().scale;
    const float32x4_t scale_beta_vec = vdupq_n_f32(scale_beta);

    Iterator in_it(in, window);
    Iterator out_it(out, window);

    constexpr int vec_size = 16;

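    // Without the AArch64 across-lane reductions (vmaxv/vaddv), the 16-lane maximum is reduced with
    // pairwise vpmax: one step folds the high and low halves, then log2(16 >> 1) = 3 further
    // pairwise stages collapse the remaining 8 lanes to a single value.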
#ifndef __aarch64__
    const int sum_stages = log2(vec_size >> 1);
#endif // __aarch64__

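    // Dispatch tag selecting the 128-bit NEON vector of T (uint8x16_t / int8x16_t) in the wrapper:: API.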
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    execute_window_loop(
        window,
        [&](const Coordinates &)
        {
            /* Get pointers */
            const T *in_ptr  = reinterpret_cast<const T *>(in_it.ptr());
            T       *out_ptr = reinterpret_cast<T *>(out_it.ptr());
            float   *tmp_ptr = reinterpret_cast<float *>(tmp);
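            // 'tmp' is caller-provided float scratch storage for one row of intermediate values;
            // it is overwritten on every iteration of the window loop.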

            T max_val;

            /* Compute Max */
            {
                // Init max value
                auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
                int  x       = 0;

                for (; x <= (input_width - vec_size); x += vec_size)
                {
                    const auto current_value = wrapper::vloadq(in_ptr + x);
                    vec_max                  = wrapper::vmax(vec_max, current_value);
                }

#ifdef __aarch64__
                max_val = wrapper::vmaxv(vec_max);
#else // __aarch64__
                auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));

                for (int i = 0; i < sum_stages; ++i)
                {
                    carry_max = wrapper::vpmax(carry_max, carry_max);
                }

                max_val = wrapper::vgetlane(carry_max, 0);
#endif // __aarch64__

                // Compute left-over elements
                for (; x < input_width; ++x)
                {
                    max_val = std::max(*(in_ptr + x), max_val);
                }
            } // Compute Max

            float sum_transformed{};

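            // Subtracting the row maximum before exponentiating keeps every exponent <= 0, so the
            // exponentials cannot overflow; the saturating vqsub below computes the non-negative
            // (max - x) differences that scale_beta then turns into beta * scale * (x - max).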
            /* Compute exponentials and sum */
            {
                /* Get max value */
                const auto vec_max = wrapper::vdup_n(max_val, wrapper::traits::vector_128_tag{});

                /* Init sum to zero */
                float32x4x4_t vec_sum = {
                    vdupq_n_f32(0.f),
                    vdupq_n_f32(0.f),
                    vdupq_n_f32(0.f),
                    vdupq_n_f32(0.f),
                };

                /* Loop over row and compute exponentials and sum */
                int x = 0;
                for (; x <= (input_width - vec_size); x += vec_size)
                {
                    auto          vec_elements     = wrapper::vloadq(in_ptr + x);
                    vec_elements                   = wrapper::vqsub(vec_max, vec_elements);
                    float32x4x4_t vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);

                    if (IS_LOG)
                    {
                        vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
                        vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
                        vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
                        vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
                        vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
                        vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
                        vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
                        vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
                    }
                    else
                    {
                        vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
                        vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
                        vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
                        vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
                        vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
                        vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
                        vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
                        vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
                    }

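                    // vst4q_f32 stores the four vectors interleaved; the matching vld4q_f32 in the
                    // normalization pass de-interleaves them the same way, so each element's value
                    // round-trips through the scratch buffer unchanged.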
                    vst4q_f32(tmp_ptr + x, vec_elements_flt);
                }

                /* Reduce sum */
                const float32x4_t sum_16_byte =
                    vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), vaddq_f32(vec_sum.val[2], vec_sum.val[3]));

                float sum;

#ifdef __aarch64__
                sum = wrapper::vaddv(sum_16_byte);
#else // __aarch64__
                auto sum_res = vpadd_f32(vget_high_f32(sum_16_byte), vget_low_f32(sum_16_byte));
                sum_res      = vpadd_f32(sum_res, sum_res);
                sum          = wrapper::vgetlane(sum_res, 0);
#endif // __aarch64__

                /* Run remaining elements */
                for (; x < input_width; ++x)
                {
                    float element{};
                    if (IS_LOG)
                    {
                        element = (max_val - in_ptr[x]) * scale_beta;
                        sum += std::exp(element);
                    }
                    else
                    {
                        element = std::exp((max_val - in_ptr[x]) * scale_beta);
                        sum += element;
                    }

                    tmp_ptr[x] = element;
                }

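                // For softmax, fold the 1/sum normalization and the requantization into a single
                // multiplier; the 256.f assumes the output tensor uses the usual 1/256 quantization
                // scale for quantized softmax. For log-softmax, log(sum) is subtracted instead.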
                if (!IS_LOG)
                {
                    sum_transformed = 256.f / sum;
                }
                else
                {
                    sum_transformed = std::log(sum);
                }
            } // Compute exponentials and sum

            /* Normalize exponentials */
            {
                constexpr bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value;

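                // sum_transformed is broadcast once per row: 256/sum for softmax, log(sum) for
                // log-softmax. For qasymm8_signed output the requantized values are additionally
                // shifted by -128 below to land in the signed 8-bit range.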
                const float32x4_t sum_vec = vdupq_n_f32(sum_transformed);

                /* Loop over row and compute softmax */
                int x = 0;
                for (; x <= (input_width - vec_size); x += vec_size)
                {
                    using int_vec_type   = wrapper::traits::neon_vector_t<T, 16>;
                    float32x4x4_t vec_in = vld4q_f32(tmp_ptr + x);
                    int_vec_type  normalized_value{};
                    if (IS_LOG)
                    {
                        const float32x4x4_t sub = {
                            vsubq_f32(vec_in.val[0], sum_vec),
                            vsubq_f32(vec_in.val[1], sum_vec),
                            vsubq_f32(vec_in.val[2], sum_vec),
                            vsubq_f32(vec_in.val[3], sum_vec),
                        };
                        normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub);
                    }
                    else
                    {
                        float32x4x4_t mul = {
                            vmulq_f32(vec_in.val[0], sum_vec),
                            vmulq_f32(vec_in.val[1], sum_vec),
                            vmulq_f32(vec_in.val[2], sum_vec),
                            vmulq_f32(vec_in.val[3], sum_vec),
                        };

                        if (is_qasymm8_signed)
                        {
                            const auto offset_vec = wrapper::vdup_n(128.f, wrapper::traits::vector_128_tag{});
                            mul.val[0]            = wrapper::vsub(mul.val[0], offset_vec);
                            mul.val[1]            = wrapper::vsub(mul.val[1], offset_vec);
                            mul.val[2]            = wrapper::vsub(mul.val[2], offset_vec);
                            mul.val[3]            = wrapper::vsub(mul.val[3], offset_vec);
                        }

                        normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(mul);
                    }
                    wrapper::vstore(out_ptr + x, normalized_value);
                }
                /* Run remaining elements */
                for (; x < input_width; ++x)
                {
                    if (IS_LOG)
                    {
                        out_ptr[x] = utils::cast::saturate_cast<T>(tmp_ptr[x] - sum_transformed);
                    }
                    else
                    {
                        out_ptr[x] = utils::cast::saturate_cast<T>((tmp_ptr[x] * sum_transformed) -
                                                                   (is_qasymm8_signed ? 128.f : 0));
                    }
                }
            } // Normalize exponentials
        },
        in_it, out_it);
}

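// Explicit instantiations: softmax and log-softmax for both 8-bit quantized data types.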
template void neon_softmax_quantized<qasymm8_signed_t, true>(
    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);

template void neon_softmax_quantized<qasymm8_signed_t, false>(
    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);

template void neon_softmax_quantized<qasymm8_t, true>(
    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);

template void neon_softmax_quantized<qasymm8_t, false>(
    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
} // namespace cpu
} // namespace arm_compute