blob: bc5b0c0696ec8829d6629cddeb8abf2c8abdd6cb [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2017-2020 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"
25
26#include "arm_compute/core/AccessWindowStatic.h"
Anthony Barbiereaefd002018-07-20 17:49:35 +010027#include "arm_compute/core/CPP/Validate.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028#include "arm_compute/core/Error.h"
29#include "arm_compute/core/Helpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/NEFixedPoint.h"
32#include "arm_compute/core/NEON/NEMath.h"
Manuel Bottini21079dd2019-10-29 17:20:09 +000033#include "arm_compute/core/NEON/wrapper/wrapper.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010034#include "arm_compute/core/TensorInfo.h"
35#include "arm_compute/core/Utils.h"
36#include "arm_compute/core/Validate.h"
37#include "arm_compute/core/Window.h"
Georgios Pinitas303f0db2018-11-19 11:56:51 +000038#include "arm_compute/core/utils/misc/SaturateCast.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010039
40#include <algorithm>
41#include <arm_neon.h>
42#include <cfloat>
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000043#include <functional>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010044
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000045namespace arm_compute
46{
/** Convert a float vector-of-vectors back to a packed integer vector.
 *
 * Only explicit specializations (below) are defined; the primary template is declared only.
 */
template <typename FloatVecT, typename IntVecT>
IntVecT convert_float_to_int(const FloatVecT &in);

/** Widen a packed integer vector into a float vector-of-vectors.
 *
 * Only explicit specializations (below) are defined; the primary template is declared only.
 */
template <typename FloatVecT, typename IntVecT>
FloatVecT convert_int_to_float(const IntVecT &in);
52
53template <>
54uint8x16_t convert_float_to_int<float32x4x4_t, uint8x16_t>(const float32x4x4_t &in)
55{
56 uint8x16_t out;
57 convert_float32x4x4_to_uint8x16(in, out);
58 return out;
59}
60
61template <>
62int8x16_t convert_float_to_int<float32x4x4_t, int8x16_t>(const float32x4x4_t &in)
63{
64 int8x16_t out;
65 convert_float32x4x4_to_int8x16(in, out);
66 return out;
67}
68
69template <>
70float32x4x4_t convert_int_to_float<float32x4x4_t, uint8x16_t>(const uint8x16_t &in)
71{
72 return convert_uint8x16_to_float32x4x4(in);
73}
74
75template <>
76float32x4x4_t convert_int_to_float<float32x4x4_t, int8x16_t>(const int8x16_t &in)
77{
78 return convert_int8x16_to_float32x4x4(in);
79}
80
Anthony Barbier6ff3b192017-09-04 18:44:23 +010081namespace
82{
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000083Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
Michalis Spyrouafa5d812017-11-30 14:25:57 +000084{
Anthony Barbiereaefd002018-07-20 17:49:35 +010085 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +000086 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
Pablo Tellob49a7152017-07-11 16:31:35 +010087
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000088 // Validate in case of configured output
89 if(output.total_size() != 0)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010090 {
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000091 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000092 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
93 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +010094 }
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000095
96 return Status{};
97}
98
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000099template <typename T>
100void logits_1d_max(const ITensor &in, ITensor &out, const Window &window)
101{
Manuel Bottini21079dd2019-10-29 17:20:09 +0000102 /** NEON vector tag type. */
103 using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
104
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000105 constexpr int window_step_x = 16 / sizeof(T);
106 const auto window_start_x = static_cast<int>(window.x().start());
107 const auto window_end_x = static_cast<int>(window.x().end());
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000108
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000109 Window win{ window };
110 win.set(Window::DimX, Window::Dimension(0, 1, 1));
111 Iterator input(&in, win);
112 Iterator output(&out, win);
113
114 const int sum_stages = log2(window_step_x / 2);
115 execute_window_loop(win, [&](const Coordinates &)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000116 {
117 // Get pointers
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000118 const auto in_ptr = reinterpret_cast<const T *>(input.ptr());
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000119 const auto out_ptr = reinterpret_cast<T *>(output.ptr());
120
121 // Init max value
Manuel Bottini21079dd2019-10-29 17:20:09 +0000122 auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000123 int x = window_start_x;
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000124
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000125 for(; x <= (window_end_x - window_step_x); x += window_step_x)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000126 {
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000127 const auto current_value = wrapper::vloadq(in_ptr + x);
Manuel Bottini21079dd2019-10-29 17:20:09 +0000128 vec_max = wrapper::vmax(vec_max, current_value);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000129 }
Manuel Bottini21079dd2019-10-29 17:20:09 +0000130 auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));
131
132 for(int i = 0; i < sum_stages; ++i)
133 {
134 carry_max = wrapper::vpmax(carry_max, carry_max);
135 }
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000136 T max_val = wrapper::vgetlane(carry_max, 0);
137
138 // Compute left-over elements
139 for(; x < window_end_x; ++x)
140 {
141 max_val = *(in_ptr + x) > max_val ? *(in_ptr + x) : max_val;
142 }
143
144 *out_ptr = max_val;
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000145 },
146 input, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100147}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100148} // namespace
149
150NELogits1DMaxKernel::NELogits1DMaxKernel()
151 : _func(nullptr), _border_size()
152{
153}
154
155BorderSize NELogits1DMaxKernel::border_size() const
156{
157 return _border_size;
158}
159
160void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
161{
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000162 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000163 ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), output->info());
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000164 // Perform validation step
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000165 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*input->info(), *output->info()));
166 // Configure kernel window
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000167
168 // Softmax across the x dimension
169 const TensorShape output_shape = TensorShape(input->info()->tensor_shape()).set(0, 1);
170 // Output auto initialization if not yet initialized
171 auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->quantization_info());
172
173 Window win = calculate_max_window(*input->info(), Steps());
174 Coordinates coord;
175 coord.set_num_dimensions(output->info()->num_dimensions());
176 output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100177
178 switch(input->info()->data_type())
179 {
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000180 case DataType::QASYMM8:
181 _func = &logits_1d_max<qasymm8_t>;
182 break;
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000183 case DataType::QASYMM8_SIGNED:
184 _func = &logits_1d_max<qasymm8_signed_t>;
185 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000186#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000187 case DataType::F16:
188 _func = &logits_1d_max<float16_t>;
Pablo Tellob49a7152017-07-11 16:31:35 +0100189 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000190#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000191 case DataType::F32:
192 _func = &logits_1d_max<float>;
193 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100194 default:
195 ARM_COMPUTE_ERROR("Unsupported data type.");
196 }
197
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000198 _input = input;
199 _output = output;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100200
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000201 const int input_width = input->info()->valid_region().shape.x();
202 const int num_elems_processed_per_iteration = 16U / data_size_from_type(input->info()->data_type());
203 const int num_elems_read_per_iteration = ceil_to_multiple(input_width, num_elems_processed_per_iteration);
204
205 _border_size = BorderSize(0, num_elems_read_per_iteration - input_width, 0, 0);
206
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000207 INEKernel::configure(win);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000208}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100209
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000210Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
211{
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000212 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000213 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*input, *output));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100214
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000215 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100216}
217
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100218void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100219{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100220 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100221 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
222 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
223 ARM_COMPUTE_ERROR_ON(_func == nullptr);
224
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000225 (*_func)(*_input, *_output, window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100226}
227
228namespace
229{
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000230Status validate_arguments_logits_softmax(const ITensorInfo &input, const ITensorInfo &max,
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000231 const ITensorInfo &output, const float beta, const ITensorInfo &tmp, bool is_log)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100232{
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100233 ARM_COMPUTE_UNUSED(beta);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000234 // Check input
Anthony Barbiereaefd002018-07-20 17:49:35 +0100235 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000236 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
Pablo Tellob49a7152017-07-11 16:31:35 +0100237
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000238 const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());
Georgios Pinitas9247c922017-06-28 18:29:47 +0100239
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000240 // Check max
241 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &max);
242 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(input.tensor_shape()).set(0, 1), max.tensor_shape());
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000243 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &max);
Georgios Pinitas9247c922017-06-28 18:29:47 +0100244
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000245 // Check output if configured
246 if(output.total_size() != 0)
Georgios Pinitas9247c922017-06-28 18:29:47 +0100247 {
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000248 const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(input.data_type(), is_log) : output.quantization_info();
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000249 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
250 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &output);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000251 ARM_COMPUTE_RETURN_ERROR_ON(output.quantization_info() != output_quantization);
Georgios Pinitas9247c922017-06-28 18:29:47 +0100252 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100253
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000254 // Check tmp if configured
255 if(tmp.total_size() != 0)
256 {
257 const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
258 ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000259 // We could potentially reduce tmp memory if we could predict or make an assumption
260 // on the maximum number of threads that will run in parallel.
261 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &tmp);
262 }
263
264 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100265}
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000266template <typename T, bool is_log>
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000267void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
268{
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000269 static_assert(std::is_same<T, qasymm8_t>::value
270 || std::is_same<T, qasymm8_signed_t>::value,
271 "quantized type should be either qasymm8_t or qasymm8_signed_t.");
272
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000273 const int start_x = in.info()->valid_region().anchor.x();
274 const int input_width = in.info()->valid_region().shape.x();
275
Manuel Bottini21079dd2019-10-29 17:20:09 +0000276 const float scale_beta = -beta * in.info()->quantization_info().uniform().scale;
277 const auto scale_beta_vec = vdupq_n_f32(scale_beta);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000278
Manuel Bottini21079dd2019-10-29 17:20:09 +0000279 Iterator in_it(&in, window);
280 Iterator max_it(&max, window);
281 Iterator out_it(&out, window);
282 constexpr int vec_size = 16;
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000283
284 execute_window_loop(window, [&](const Coordinates &)
285 {
286 /* Get pointers */
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000287 const auto in_ptr = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
288 const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000289 const auto tmp_ptr = reinterpret_cast<float *>(tmp);
290
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100291 float sum{};
292 float sum_inversed{};
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000293
294 /* Compute exponentials and sum */
295 {
296 /* Get max value */
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000297 const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
298 const auto vec_max = wrapper::vdup_n(max_val, wrapper::traits::vector_128_tag{});
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000299
300 /* Init sum to zero */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000301 float32x4x4_t vec_sum =
302 {
303 vdupq_n_f32(0.f),
304 vdupq_n_f32(0.f),
305 vdupq_n_f32(0.f),
306 vdupq_n_f32(0.f),
307 };
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000308
309 /* Loop over row and compute exponentials and sum */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000310 int x = 0;
311 for(; x <= (input_width - vec_size); x += vec_size)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000312 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000313 auto vec_elements = wrapper::vloadq(in_ptr + x);
Georgios Pinitas2cd7a372020-05-12 21:03:56 +0100314 vec_elements = wrapper::vqsub(vec_max, vec_elements);
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000315 auto vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000316
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100317 if(is_log)
318 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000319 vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
320 vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
321 vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
322 vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
323 vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
324 vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
325 vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
326 vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100327 }
328 else
329 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000330 vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
331 vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
332 vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
333 vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
334 vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
335 vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
336 vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
337 vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100338 }
Manuel Bottini21079dd2019-10-29 17:20:09 +0000339
340 vst4q_f32(tmp_ptr + x, vec_elements_flt);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000341 }
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100342
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000343 /* Reduce sum */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000344 const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
345 auto sum_res = vpadd_f32(vget_high_f32(sum_16_byte), vget_low_f32(sum_16_byte));
346 sum_res = vpadd_f32(sum_res, sum_res);
347 sum = wrapper::vgetlane(sum_res, 0);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000348
349 /* Run remaining elements */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000350 for(; x < input_width; ++x)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000351 {
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100352 float element{};
353 if(is_log)
354 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000355 element = (max_val - in_ptr[x]) * scale_beta;
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100356 sum += std::exp(element);
357 }
358 else
359 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000360 element = std::exp((max_val - in_ptr[x]) * scale_beta);
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100361 sum += element;
362 }
363
Manuel Bottini21079dd2019-10-29 17:20:09 +0000364 tmp_ptr[x] = element;
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000365 }
366
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100367 if(!is_log)
368 {
369 sum_inversed = 256.f / sum;
370 }
Sang-Hoon Parka0205b92020-07-07 09:36:09 +0100371 else
372 {
373 sum = std::log(sum);
374 }
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000375 }
376
377 /* Normalize exponentials */
378 {
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000379 constexpr bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value;
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000380 /* Loop over row and compute softmax */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000381 int x = 0;
382 for(; x <= (input_width - vec_size); x += vec_size)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000383 {
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000384 using int_vec_type = wrapper::traits::neon_vector_t<T, 16>;
Manuel Bottini21079dd2019-10-29 17:20:09 +0000385 float32x4x4_t vec_in = vld4q_f32(tmp_ptr + x);
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000386 int_vec_type normalized_value{};
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100387 if(is_log)
388 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000389 const float32x4x4_t sub =
390 {
391 vsubq_f32(vec_in.val[0], vdupq_n_f32(sum)),
392 vsubq_f32(vec_in.val[1], vdupq_n_f32(sum)),
393 vsubq_f32(vec_in.val[2], vdupq_n_f32(sum)),
394 vsubq_f32(vec_in.val[3], vdupq_n_f32(sum)),
395 };
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000396 normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub);
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100397 }
398 else
399 {
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000400 float32x4x4_t mul =
Manuel Bottini21079dd2019-10-29 17:20:09 +0000401 {
402 vmulq_f32(vec_in.val[0], vdupq_n_f32(sum_inversed)),
403 vmulq_f32(vec_in.val[1], vdupq_n_f32(sum_inversed)),
404 vmulq_f32(vec_in.val[2], vdupq_n_f32(sum_inversed)),
405 vmulq_f32(vec_in.val[3], vdupq_n_f32(sum_inversed)),
406 };
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000407
408 if(is_qasymm8_signed)
409 {
410 const auto offset_vec = wrapper::vdup_n(128.f, wrapper::traits::vector_128_tag{});
411 mul.val[0] = wrapper::vsub(mul.val[0], offset_vec);
412 mul.val[1] = wrapper::vsub(mul.val[1], offset_vec);
413 mul.val[2] = wrapper::vsub(mul.val[2], offset_vec);
414 mul.val[3] = wrapper::vsub(mul.val[3], offset_vec);
415 }
416
417 normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(mul);
Manuel Bottini21079dd2019-10-29 17:20:09 +0000418 }
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000419 wrapper::vstore(out_ptr + x, normalized_value);
Manuel Bottini21079dd2019-10-29 17:20:09 +0000420 }
421 /* Run remaining elements */
422 for(; x < input_width; ++x)
423 {
424 if(is_log)
425 {
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000426 out_ptr[x] = utils::cast::saturate_cast<T>(tmp_ptr[x] - sum);
Manuel Bottini21079dd2019-10-29 17:20:09 +0000427 }
428 else
429 {
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000430 out_ptr[x] = utils::cast::saturate_cast<T>((tmp_ptr[x] * sum_inversed) - (is_qasymm8_signed ? 128.f : 0));
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100431 }
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000432 }
433 }
434 },
435 in_it, max_it, out_it);
436}
437
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100438template <typename T, bool is_log = false>
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000439void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
440 ITensor &out, const float beta, const Window &window)
441{
442 const int start_x = in.info()->valid_region().anchor.x();
443 const int input_width = in.info()->valid_region().shape.x();
444
445 Iterator in_it(&in, window);
446 Iterator max_it(&max, window);
447 Iterator out_it(&out, window);
448
Manuel Bottini21079dd2019-10-29 17:20:09 +0000449 /** NEON vector tag type. */
450 using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
451
452 constexpr int vec_size = 16 / sizeof(T);
453 const int sum_stages = log2(vec_size / 2);
454
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000455 execute_window_loop(window, [&](const Coordinates &)
456 {
457 /* Get pointers */
458 const auto in_ptr = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
459 const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
460 const auto tmp_ptr = reinterpret_cast<T *>(tmp);
461
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100462 T sum{};
463 T sum_inversed{};
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000464
465 /* Compute exponentials and sum */
466 {
467 /* Get max value */
468 const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
Manuel Bottini21079dd2019-10-29 17:20:09 +0000469 const auto vec_max = wrapper::vdup_n(max_val, ExactTagType{});
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000470
471 /* Init sum to zero */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000472 auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000473
474 /* Loop over row and compute exponentials and sum */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000475 int x = 0;
476 for(; x <= (input_width - vec_size); x += vec_size)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000477 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000478 auto vec_elements = wrapper::vloadq(in_ptr + x);
479 vec_elements = wrapper::vsub(vec_elements, vec_max);
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100480 if(is_log)
481 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000482 vec_elements = wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}));
483 vec_sum = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100484 }
485 else
486 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000487 vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{})));
488 vec_sum = wrapper::vadd(vec_sum, vec_elements);
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100489 }
Manuel Bottini21079dd2019-10-29 17:20:09 +0000490 wrapper::vstore(tmp_ptr + x, vec_elements);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000491 }
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100492
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000493 /* Reduce sum */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000494 auto sum_res = wrapper::vpadd(wrapper::vgethigh(vec_sum), wrapper::vgetlow(vec_sum));
495 for(int i = 0; i < sum_stages; ++i)
496 {
497 sum_res = wrapper::vpadd(sum_res, sum_res);
498 }
499 sum = wrapper::vgetlane(sum_res, 0);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000500
501 /* Run remaining elements */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000502 for(; x < input_width; ++x)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000503 {
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100504 T element{};
505
506 if(is_log)
507 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000508 element = (in_ptr[x] - max_val) * beta;
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100509 sum += std::exp(element);
510 }
511 else
512 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000513 element = std::exp((in_ptr[x] - max_val) * beta);
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100514 sum += element;
515 }
Manuel Bottini21079dd2019-10-29 17:20:09 +0000516 tmp_ptr[x] = element;
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000517 }
518
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100519 if(!is_log)
520 {
521 sum_inversed = T(1) / sum;
522 }
Sang-Hoon Parka0205b92020-07-07 09:36:09 +0100523 else
524 {
525 sum = static_cast<T>(std::log(sum));
526 }
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000527 }
528
529 /* Normalize exponentials */
530 {
531 /* Loop over row and compute softmax */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000532 int x = 0;
533 for(; x <= (input_width - vec_size); x += vec_size)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000534 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000535 auto vec_in = wrapper::vloadq(tmp_ptr + x);
536 auto normalized_value = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100537 if(is_log)
538 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000539 normalized_value = wrapper::vsub(vec_in, wrapper::vdup_n(static_cast<T>(sum), ExactTagType{}));
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100540 }
541 else
542 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000543 normalized_value = wrapper::vmul(vec_in, wrapper::vdup_n(static_cast<T>(sum_inversed), ExactTagType{}));
544 }
545 wrapper::vstore(out_ptr + x, normalized_value);
546 }
547 /* Run remaining elements */
548 for(; x < input_width; ++x)
549 {
550 if(is_log)
551 {
552 out_ptr[x] = tmp_ptr[x] - sum;
553 }
554 else
555 {
556 out_ptr[x] = tmp_ptr[x] * sum_inversed;
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100557 }
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000558 }
559 }
560 },
561 in_it, max_it, out_it);
562}
563} // namespace
564
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100565template <bool IS_LOG>
566NELogits1DSoftmaxKernel<IS_LOG>::NELogits1DSoftmaxKernel()
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000567 : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _beta(1.0f), _tmp(nullptr)
568{
569}
570
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100571template <bool IS_LOG>
572void NELogits1DSoftmaxKernel<IS_LOG>::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000573{
574 ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
575 ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), max->info(), output->info(), tmp->info());
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000576 // Perform validation step
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000577 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*input->info(), *max->info(), *output->info(), beta, *tmp->info(), IS_LOG));
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000578
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000579 // Configure kernel window
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000580 const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input->info()->data_type());
581
582 // Output auto initialization if not yet initialized
583 const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(input->info()->data_type(), IS_LOG) : output->info()->quantization_info();
584 auto_init_if_empty(*output->info(), TensorInfo(*input->info()).set_quantization_info(output_quantization).reset_padding());
585
586 // Tmp auto initialization if not yet initialized
587 const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input->info()->data_type();
588 auto_init_if_empty(*tmp->info(), TensorInfo(*input->info()).set_data_type(tmp_data_type).reset_padding());
589
590 // Configure kernel window
591 Window win = calculate_max_window(*max->info(), Steps());
592 Coordinates coord;
593 coord.set_num_dimensions(output->info()->num_dimensions());
594 output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100595
596 switch(input->info()->data_type())
597 {
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000598 case DataType::QASYMM8:
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000599 _func = &logits_1d_softmax_qasymm8<qasymm8_t, IS_LOG>;
600 break;
601 case DataType::QASYMM8_SIGNED:
602 _func = &logits_1d_softmax_qasymm8<qasymm8_signed_t, IS_LOG>;
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000603 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000604#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000605 case DataType::F16:
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100606 _func = &logits_1d_softmax_float<float16_t, IS_LOG>;
Pablo Tellob49a7152017-07-11 16:31:35 +0100607 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000608#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000609 case DataType::F32:
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100610 _func = &logits_1d_softmax_float<float, IS_LOG>;
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000611 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100612 default:
613 ARM_COMPUTE_ERROR("Unsupported data type.");
Pablo Tellob49a7152017-07-11 16:31:35 +0100614 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100615 }
616
617 _input = input;
618 _max = max;
619 _output = output;
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100620 _beta = beta;
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000621 _tmp = tmp;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100622
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000623 INEKernel::configure(win);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000624}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100625
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100626template <bool IS_LOG>
627Status NELogits1DSoftmaxKernel<IS_LOG>::validate(const ITensorInfo *input, const ITensorInfo *max,
628 const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000629{
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000630 ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000631 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*input, *max, *output, beta, *tmp, IS_LOG));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100632
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000633 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100634}
635
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100636template <bool IS_LOG>
637void NELogits1DSoftmaxKernel<IS_LOG>::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100638{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100639 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100640 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
641 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100642
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000643 const unsigned int num_elems_processed_per_iteration = _input->info()->valid_region().shape.x();
644 const unsigned int tmp_size_for_thread = _tmp->info()->element_size() * num_elems_processed_per_iteration;
645
646 ARM_COMPUTE_ERROR_ON(_tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
647
648 void *tmp_for_thread = _tmp->buffer() + (info.thread_id * tmp_size_for_thread);
649
650 (*_func)(*_input, *_max, tmp_for_thread, *_output, _beta, window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100651}
652
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100653template class NELogits1DSoftmaxKernel<true>;
654template class NELogits1DSoftmaxKernel<false>;
655
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000656} // namespace arm_compute