blob: 790c8bacc5928746ecbe6f9574f078dc76fad7de [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michalis Spyrou2dc7e402020-02-28 14:41:35 +00002 * Copyright (c) 2017-2020 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"
25
26#include "arm_compute/core/AccessWindowStatic.h"
Anthony Barbiereaefd002018-07-20 17:49:35 +010027#include "arm_compute/core/CPP/Validate.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028#include "arm_compute/core/Error.h"
29#include "arm_compute/core/Helpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/NEFixedPoint.h"
32#include "arm_compute/core/NEON/NEMath.h"
Manuel Bottini21079dd2019-10-29 17:20:09 +000033#include "arm_compute/core/NEON/wrapper/wrapper.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010034#include "arm_compute/core/TensorInfo.h"
35#include "arm_compute/core/Utils.h"
36#include "arm_compute/core/Validate.h"
37#include "arm_compute/core/Window.h"
Georgios Pinitas303f0db2018-11-19 11:56:51 +000038#include "arm_compute/core/utils/misc/SaturateCast.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010039
40#include <algorithm>
41#include <arm_neon.h>
42#include <cfloat>
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000043#include <functional>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010044
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000045namespace arm_compute
46{
/** Convert a block of four float32x4 registers to a single 16-lane integer vector.
 *  Forward declaration; only the specializations below are defined. */
template <typename float_vec_type, typename int_vec_type>
int_vec_type convert_float_to_int(const float_vec_type &in);

/** Widen a 16-lane integer vector into a block of four float32x4 registers.
 *  Forward declaration; only the specializations below are defined. */
template <typename float_vec_type, typename int_vec_type>
float_vec_type convert_int_to_float(const int_vec_type &in);

/** QASYMM8 flavour: delegates to the library helper convert_float32x4x4_to_uint8x16. */
template <>
uint8x16_t convert_float_to_int<float32x4x4_t, uint8x16_t>(const float32x4x4_t &in)
{
    uint8x16_t out;
    convert_float32x4x4_to_uint8x16(in, out);
    return out;
}

/** QASYMM8_SIGNED flavour: delegates to convert_float32x4x4_to_int8x16. */
template <>
int8x16_t convert_float_to_int<float32x4x4_t, int8x16_t>(const float32x4x4_t &in)
{
    int8x16_t out;
    convert_float32x4x4_to_int8x16(in, out);
    return out;
}

/** QASYMM8 flavour: delegates to convert_uint8x16_to_float32x4x4. */
template <>
float32x4x4_t convert_int_to_float<float32x4x4_t, uint8x16_t>(const uint8x16_t &in)
{
    return convert_uint8x16_to_float32x4x4(in);
}

/** QASYMM8_SIGNED flavour: delegates to convert_int8x16_to_float32x4x4. */
template <>
float32x4x4_t convert_int_to_float<float32x4x4_t, int8x16_t>(const int8x16_t &in)
{
    return convert_int8x16_to_float32x4x4(in);
}
80
Anthony Barbier6ff3b192017-09-04 18:44:23 +010081namespace
82{
/** Validate the tensor infos for the row-max kernel.
 *
 * Accepts QASYMM8 / QASYMM8_SIGNED / F16 / F32 input. If the output is already
 * configured it must match the input's data type and quantization info, and its
 * shape must be the input shape collapsed to 1 on dimension 0 (one max per row).
 *
 * @return An error Status on the first failed check, otherwise an empty Status.
 */
Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);

    // Validate in case of configured output
    if(output.total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
    }

    return Status{};
}
98
/** Compute the maximum of each row (x dimension) of @p in and write it to @p out.
 *
 * Vectorized over 128-bit NEON registers (16/sizeof(T) lanes per step) with a
 * scalar tail loop for the leftover elements.
 *
 * @tparam T Element type (qasymm8_t, qasymm8_signed_t, float16_t or float).
 */
template <typename T>
void logits_1d_max(const ITensor &in, ITensor &out, const Window &window)
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    constexpr int window_step_x  = 16 / sizeof(T);
    const auto    window_start_x = static_cast<int>(window.x().start());
    const auto    window_end_x   = static_cast<int>(window.x().end());

    // Collapse the x dimension: each window iteration processes one whole row.
    Window win{ window };
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    Iterator input(&in, win);
    Iterator output(&out, win);

    // Number of pairwise-max halvings needed to reduce a 64-bit half-register to one lane.
    const int sum_stages = log2(window_step_x / 2);
    execute_window_loop(win, [&](const Coordinates &)
    {
        // Get pointers
        const auto in_ptr  = reinterpret_cast<const T *>(input.ptr());
        const auto out_ptr = reinterpret_cast<T *>(output.ptr());

        // Init max value to the lowest representable T so any element beats it
        auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
        int  x       = window_start_x;

        // Vectorized scan of the row
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const auto current_value = wrapper::vloadq(in_ptr + x);
            vec_max                  = wrapper::vmax(vec_max, current_value);
        }
        // Fold the 128-bit register down: high half vs low half, then pairwise max
        auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));

        for(int i = 0; i < sum_stages; ++i)
        {
            carry_max = wrapper::vpmax(carry_max, carry_max);
        }
        T max_val = wrapper::vgetlane(carry_max, 0);

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            max_val = *(in_ptr + x) > max_val ? *(in_ptr + x) : max_val;
        }

        // One scalar result per row
        *out_ptr = max_val;
    },
    input, output);
}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100148} // namespace
149
// Default-construct in an unconfigured state: no kernel function selected and no border.
NELogits1DMaxKernel::NELogits1DMaxKernel()
    : _func(nullptr), _border_size()
{
}
154
// Right-hand border computed in configure(): padding needed to round the row
// width up to a whole number of vector iterations.
BorderSize NELogits1DMaxKernel::border_size() const
{
    return _border_size;
}
159
/** Configure the row-max kernel: validate inputs, auto-initialize the output,
 *  select the typed kernel function and set up the execution window.
 *
 * @param[in]  input  Source tensor (QASYMM8/QASYMM8_SIGNED/F16/F32).
 * @param[out] output Destination tensor; auto-initialized to the input shape
 *                    collapsed to 1 on dimension 0 if not yet configured.
 */
void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), output->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*input->info(), *output->info()));
    // Configure kernel window

    // Softmax across the x dimension
    const TensorShape output_shape = TensorShape(input->info()->tensor_shape()).set(0, 1);
    // Output auto initialization if not yet initialized
    auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->quantization_info());

    Window      win = calculate_max_window(*input->info(), Steps());
    Coordinates coord;
    coord.set_num_dimensions(output->info()->num_dimensions());
    output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));

    // Select the kernel specialization matching the input data type
    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_max<qasymm8_t>;
            break;
        case DataType::QASYMM8_SIGNED:
            _func = &logits_1d_max<qasymm8_signed_t>;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_max<float16_t>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_max<float>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
    }

    _input  = input;
    _output = output;

    // Right border: distance from the valid row width to the next multiple of
    // the per-iteration vector width, so full-vector loads stay in bounds.
    const int input_width                       = input->info()->valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input->info()->data_type());
    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);

    _border_size = BorderSize(0, num_elems_read_per_iteration - input_width, 0, 0);

    INEKernel::configure(win);
}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100209
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000210Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
211{
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000212 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000213 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*input, *output));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100214
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000215 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100216}
217
// Execute the selected row-max function over the given window slice.
void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info); // no per-thread state needed by the max reduction
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_func == nullptr);

    // Dispatch to the type-specialized implementation chosen in configure()
    (*_func)(*_input, *_output, window);
}
227
228namespace
229{
/** Validate the tensor infos for the softmax/log-softmax normalization kernel.
 *
 * Checks the input type, that @p max matches the input with dimension 0
 * collapsed to 1, that a configured @p output matches the input shape/type and
 * (for quantized types) the fixed softmax output quantization, and that a
 * configured @p tmp has the right data type (F32 for quantized input) and shape.
 *
 * @return An error Status on the first failed check, otherwise an empty Status.
 */
Status validate_arguments_logits_softmax(const ITensorInfo &input, const ITensorInfo &max,
                                         const ITensorInfo &output, const float beta, const ITensorInfo &tmp, bool is_log)
{
    ARM_COMPUTE_UNUSED(beta);
    // Check input
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);

    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Check max
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &max);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(input.tensor_shape()).set(0, 1), max.tensor_shape());
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &max);

    // Check output if configured
    if(output.total_size() != 0)
    {
        // Quantized softmax output must use the library-defined quantization for softmax
        const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(input.data_type(), is_log) : output.quantization_info();
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON(output.quantization_info() != output_quantization);
    }

    // Check tmp if configured
    if(tmp.total_size() != 0)
    {
        // Quantized paths accumulate exponentials in F32 scratch space
        const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
        ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
        // We could potentially reduce tmp memory if we could predict or make an assumption
        // on the maximum number of threads that will run in parallel.
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &tmp);
    }

    return Status{};
}
/** Quantized (log-)softmax over each row of @p in, given the per-row maxima @p max.
 *
 * Pass 1 computes exponentials into the F32 scratch buffer @p tmp while
 * accumulating their sum; pass 2 normalizes the scratch values and saturates
 * back to the 8-bit output type.
 *
 * @tparam T      qasymm8_t or qasymm8_signed_t (enforced by static_assert).
 * @tparam is_log True selects log-softmax, false plain softmax (compile-time branch).
 *
 * @param tmp Per-thread F32 scratch area large enough for one full row.
 */
template <typename T, bool is_log>
void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
{
    static_assert(std::is_same<T, qasymm8_t>::value
                  || std::is_same<T, qasymm8_signed_t>::value,
                  "quantized type should be either qasymm8_t or qasymm8_signed_t.");

    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    // Negated so that (max - x) * scale_beta == beta * scale * (x - max):
    // the loop subtracts in (max - element) order on unsigned lanes.
    const float scale_beta     = -beta * in.info()->quantization_info().uniform().scale;
    const auto  scale_beta_vec = vdupq_n_f32(scale_beta);

    Iterator      in_it(&in, window);
    Iterator      max_it(&max, window);
    Iterator      out_it(&out, window);
    constexpr int vec_size = 16; // 16 8-bit lanes per 128-bit register

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<float *>(tmp);

        float sum{};
        float sum_inversed{};

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
            const auto vec_max = wrapper::vdup_n(max_val, wrapper::traits::vector_128_tag{});

            /* Init sum to zero */
            float32x4x4_t vec_sum =
            {
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
            };

            /* Loop over row and compute exponentials and sum */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                auto vec_elements     = wrapper::vloadq(in_ptr + x);
                vec_elements          = wrapper::vsub(vec_max, vec_elements);
                auto vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);

                if(is_log)
                {
                    // Log-softmax keeps the (scaled) exponents in tmp; only the sum is exponentiated
                    vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
                    vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
                    vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
                    vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
                    vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
                    vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
                    vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
                    vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
                }
                else
                {
                    // Plain softmax stores the exponentials themselves in tmp
                    vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
                    vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
                    vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
                    vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
                    vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
                    vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
                    vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
                    vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
                }

                vst4q_f32(tmp_ptr + x, vec_elements_flt);
            }

            /* Reduce sum */
            const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
            auto       sum_res     = vpadd_f32(vget_high_f32(sum_16_byte), vget_low_f32(sum_16_byte));
            sum_res                = vpadd_f32(sum_res, sum_res);
            sum                    = wrapper::vgetlane(sum_res, 0);

            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                float element{};
                if(is_log)
                {
                    element = (max_val - in_ptr[x]) * scale_beta;
                    sum += std::exp(element);
                }
                else
                {
                    element = std::exp((max_val - in_ptr[x]) * scale_beta);
                    sum += element;
                }

                tmp_ptr[x] = element;
            }

            if(!is_log)
            {
                // 256 folds the requantization step (1/256) into the normalization factor
                sum_inversed = 256.f / sum;
            }
        }

        /* Normalize exponentials */
        {
            constexpr bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value;
            /* Loop over row and compute softmax */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                using int_vec_type   = wrapper::traits::neon_vector_t<T, 16>;
                float32x4x4_t vec_in = vld4q_f32(tmp_ptr + x);
                int_vec_type normalized_value{};
                if(is_log)
                {
                    const float32x4x4_t sub =
                    {
                        vsubq_f32(vec_in.val[0], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[1], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[2], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[3], vdupq_n_f32(sum)),
                    };
                    normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub);
                }
                else
                {
                    float32x4x4_t mul =
                    {
                        vmulq_f32(vec_in.val[0], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[1], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[2], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[3], vdupq_n_f32(sum_inversed)),
                    };

                    if(is_qasymm8_signed)
                    {
                        // Shift [0,256) into the signed range before the saturating narrow
                        const auto offset_vec = wrapper::vdup_n(128.f, wrapper::traits::vector_128_tag{});
                        mul.val[0] = wrapper::vsub(mul.val[0], offset_vec);
                        mul.val[1] = wrapper::vsub(mul.val[1], offset_vec);
                        mul.val[2] = wrapper::vsub(mul.val[2], offset_vec);
                        mul.val[3] = wrapper::vsub(mul.val[3], offset_vec);
                    }

                    normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(mul);
                }
                wrapper::vstore(out_ptr + x, normalized_value);
            }
            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                if(is_log)
                {
                    out_ptr[x] = utils::cast::saturate_cast<T>(tmp_ptr[x] - sum);
                }
                else
                {
                    out_ptr[x] = utils::cast::saturate_cast<T>((tmp_ptr[x] * sum_inversed) - (is_qasymm8_signed ? 128.f : 0));
                }
            }
        }
    },
    in_it, max_it, out_it);
}
433
/** Floating-point (log-)softmax over each row of @p in, given per-row maxima @p max.
 *
 * Pass 1 computes beta-scaled, max-subtracted exponentials into @p tmp while
 * accumulating their sum; pass 2 normalizes the scratch values into @p out.
 *
 * @tparam T      float or float16_t.
 * @tparam is_log True selects log-softmax, false plain softmax (compile-time branch).
 *
 * @param tmp Per-thread scratch area (same type as T) holding one full row.
 */
template <typename T, bool is_log = false>
void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
                             ITensor &out, const float beta, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    Iterator in_it(&in, window);
    Iterator max_it(&max, window);
    Iterator out_it(&out, window);

    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    constexpr int vec_size   = 16 / sizeof(T);
    // Pairwise-add halvings needed to collapse a half-register to a single lane
    const int     sum_stages = log2(vec_size / 2);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<T *>(tmp);

        T sum{};
        T sum_inversed{};

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
            const auto vec_max = wrapper::vdup_n(max_val, ExactTagType{});

            /* Init sum to zero */
            auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});

            /* Loop over row and compute exponentials and sum */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                auto vec_elements = wrapper::vloadq(in_ptr + x);
                // Subtract the row max for numerical stability before exponentiating
                vec_elements      = wrapper::vsub(vec_elements, vec_max);
                if(is_log)
                {
                    // Log-softmax: tmp holds beta*(x - max); only the sum is exponentiated
                    vec_elements = wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}));
                    vec_sum      = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
                }
                else
                {
                    // Softmax: tmp holds exp(beta*(x - max))
                    vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{})));
                    vec_sum      = wrapper::vadd(vec_sum, vec_elements);
                }
                wrapper::vstore(tmp_ptr + x, vec_elements);
            }

            /* Reduce sum */
            auto sum_res = wrapper::vpadd(wrapper::vgethigh(vec_sum), wrapper::vgetlow(vec_sum));
            for(int i = 0; i < sum_stages; ++i)
            {
                sum_res = wrapper::vpadd(sum_res, sum_res);
            }
            sum = wrapper::vgetlane(sum_res, 0);

            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                T element{};

                if(is_log)
                {
                    element = (in_ptr[x] - max_val) * beta;
                    sum += std::exp(element);
                }
                else
                {
                    element = std::exp((in_ptr[x] - max_val) * beta);
                    sum += element;
                }
                tmp_ptr[x] = element;
            }

            if(!is_log)
            {
                sum_inversed = T(1) / sum;
            }
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                auto vec_in           = wrapper::vloadq(tmp_ptr + x);
                auto normalized_value = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
                if(is_log)
                {
                    // log-softmax: subtract log-sum (sum already holds the reduced total)
                    normalized_value = wrapper::vsub(vec_in, wrapper::vdup_n(static_cast<T>(sum), ExactTagType{}));
                }
                else
                {
                    normalized_value = wrapper::vmul(vec_in, wrapper::vdup_n(static_cast<T>(sum_inversed), ExactTagType{}));
                }
                wrapper::vstore(out_ptr + x, normalized_value);
            }
            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                if(is_log)
                {
                    out_ptr[x] = tmp_ptr[x] - sum;
                }
                else
                {
                    out_ptr[x] = tmp_ptr[x] * sum_inversed;
                }
            }
        }
    },
    in_it, max_it, out_it);
}
555} // namespace
556
// Default-construct in an unconfigured state; beta defaults to 1.0f.
template <bool IS_LOG>
NELogits1DSoftmaxKernel<IS_LOG>::NELogits1DSoftmaxKernel()
    : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _beta(1.0f), _tmp(nullptr)
{
}
562
/** Configure the (log-)softmax normalization kernel.
 *
 * Validates all tensor infos, auto-initializes @p output (with the fixed softmax
 * quantization for quantized inputs) and @p tmp (F32 scratch for quantized
 * inputs, otherwise same type as input), selects the typed kernel function and
 * sets the execution window over @p max (one iteration per row).
 *
 * @param[in]  input  Source logits tensor.
 * @param[in]  max    Per-row maxima produced by @ref NELogits1DMaxKernel.
 * @param[out] output Destination tensor.
 * @param[in]  beta   Softmax temperature scaling factor.
 * @param[in]  tmp    Per-thread scratch tensor.
 */
template <bool IS_LOG>
void NELogits1DSoftmaxKernel<IS_LOG>::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), max->info(), output->info(), tmp->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*input->info(), *max->info(), *output->info(), beta, *tmp->info(), IS_LOG));

    // Configure kernel window
    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input->info()->data_type());

    // Output auto initialization if not yet initialized
    const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(input->info()->data_type(), IS_LOG) : output->info()->quantization_info();
    auto_init_if_empty(*output->info(), TensorInfo(*input->info()).set_quantization_info(output_quantization).reset_padding());

    // Tmp auto initialization if not yet initialized
    const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input->info()->data_type();
    auto_init_if_empty(*tmp->info(), TensorInfo(*input->info()).set_data_type(tmp_data_type).reset_padding());

    // Configure kernel window
    Window      win = calculate_max_window(*max->info(), Steps());
    Coordinates coord;
    coord.set_num_dimensions(output->info()->num_dimensions());
    output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));

    // Select the kernel specialization matching the input data type
    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_softmax_qasymm8<qasymm8_t, IS_LOG>;
            break;
        case DataType::QASYMM8_SIGNED:
            _func = &logits_1d_softmax_qasymm8<qasymm8_signed_t, IS_LOG>;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_softmax_float<float16_t, IS_LOG>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_softmax_float<float, IS_LOG>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
            break;
    }

    _input  = input;
    _max    = max;
    _output = output;
    _beta   = beta;
    _tmp    = tmp;

    INEKernel::configure(win);
}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100617
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100618template <bool IS_LOG>
619Status NELogits1DSoftmaxKernel<IS_LOG>::validate(const ITensorInfo *input, const ITensorInfo *max,
620 const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000621{
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000622 ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000623 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*input, *max, *output, beta, *tmp, IS_LOG));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100624
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000625 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100626}
627
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100628template <bool IS_LOG>
629void NELogits1DSoftmaxKernel<IS_LOG>::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100630{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100631 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100632 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
633 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100634
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000635 const unsigned int num_elems_processed_per_iteration = _input->info()->valid_region().shape.x();
636 const unsigned int tmp_size_for_thread = _tmp->info()->element_size() * num_elems_processed_per_iteration;
637
638 ARM_COMPUTE_ERROR_ON(_tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
639
640 void *tmp_for_thread = _tmp->buffer() + (info.thread_id * tmp_size_for_thread);
641
642 (*_func)(*_input, *_max, tmp_for_thread, *_output, _beta, window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100643}
644
// Explicit instantiations: <true> = log-softmax kernel, <false> = softmax kernel.
template class NELogits1DSoftmaxKernel<true>;
template class NELogits1DSoftmaxKernel<false>;
647
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000648} // namespace arm_compute