/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/core/NEON/kernels/NESoftmaxLayerKernel.h"

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "src/core/AccessWindowStatic.h"
#include "src/core/CPP/Validate.h"
#include "src/core/NEON/NEFixedPoint.h"
#include "src/core/NEON/NEMath.h"
#include "src/core/NEON/wrapper/wrapper.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
#include "support/SaturateCast.h"

#include <algorithm>
#include <arm_neon.h>
#include <cfloat>
#include <functional>

namespace arm_compute
{
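// Small helpers to move between 16 quantized lanes and four float32x4_t registers.
// The specializations below simply forward to the NEON conversion utilities for
// QASYMM8 (uint8x16_t) and QASYMM8_SIGNED (int8x16_t) data.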
template <typename float_vec_type, typename int_vec_type>
int_vec_type convert_float_to_int(const float_vec_type &in);

template <typename float_vec_type, typename int_vec_type>
float_vec_type convert_int_to_float(const int_vec_type &in);

template <>
uint8x16_t convert_float_to_int<float32x4x4_t, uint8x16_t>(const float32x4x4_t &in)
{
    uint8x16_t out;
    convert_float32x4x4_to_uint8x16(in, out);
    return out;
}

template <>
int8x16_t convert_float_to_int<float32x4x4_t, int8x16_t>(const float32x4x4_t &in)
{
    int8x16_t out;
    convert_float32x4x4_to_int8x16(in, out);
    return out;
}

template <>
float32x4x4_t convert_int_to_float<float32x4x4_t, uint8x16_t>(const uint8x16_t &in)
{
    return convert_uint8x16_to_float32x4x4(in);
}

template <>
float32x4x4_t convert_int_to_float<float32x4x4_t, int8x16_t>(const int8x16_t &in)
{
    return convert_int8x16_to_float32x4x4(in);
}

namespace
{
Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);

    // Validate in case of configured output
    if(output.total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
    }

    return Status{};
}

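// Reduce each row of the input to its maximum value: the row is scanned a full
// 128-bit vector at a time, the per-lane maxima are then folded with pairwise
// vpmax stages, and any leftover elements are handled scalarly.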
template <typename T>
void logits_1d_max(const ITensor &in, ITensor &out, const Window &window)
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    constexpr int window_step_x  = 16 / sizeof(T);
    const auto    window_start_x = static_cast<int>(window.x().start());
    const auto    window_end_x   = static_cast<int>(window.x().end());

    Window win{ window };
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    Iterator input(&in, win);
    Iterator output(&out, win);

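    // After folding the high and low halves, log2(window_step_x / 2) pairwise-max
    // stages are enough to bring the row maximum down to a single lane.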
    const int sum_stages = log2(window_step_x / 2);
    execute_window_loop(win, [&](const Coordinates &)
    {
        // Get pointers
        const auto in_ptr  = reinterpret_cast<const T *>(input.ptr());
        const auto out_ptr = reinterpret_cast<T *>(output.ptr());

        // Init max value
        auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
        int  x       = window_start_x;

        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const auto current_value = wrapper::vloadq(in_ptr + x);
            vec_max                  = wrapper::vmax(vec_max, current_value);
        }
        auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));

        for(int i = 0; i < sum_stages; ++i)
        {
            carry_max = wrapper::vpmax(carry_max, carry_max);
        }
        T max_val = wrapper::vgetlane(carry_max, 0);

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            max_val = *(in_ptr + x) > max_val ? *(in_ptr + x) : max_val;
        }

        *out_ptr = max_val;
    },
    input, output);
}
} // namespace

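// NELogits1DMaxKernel produces the per-row maximum that the softmax kernel later
// subtracts from every element, keeping the exponentials numerically stable.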
NELogits1DMaxKernel::NELogits1DMaxKernel()
    : _func(nullptr), _border_size()
{
}

BorderSize NELogits1DMaxKernel::border_size() const
{
    return _border_size;
}

void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), output->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*input->info(), *output->info()));
    // Configure kernel window

    // Softmax across the x dimension
    const TensorShape output_shape = TensorShape(input->info()->tensor_shape()).set(0, 1);
    // Output auto initialization if not yet initialized
    auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->quantization_info());

    Window win = calculate_max_window(*input->info(), Steps());
    Coordinates coord;
    coord.set_num_dimensions(output->info()->num_dimensions());
    output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));

    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_max<qasymm8_t>;
            break;
        case DataType::QASYMM8_SIGNED:
            _func = &logits_1d_max<qasymm8_signed_t>;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_max<float16_t>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_max<float>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
    }

    _input  = input;
    _output = output;

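    // The reported right-hand border is the distance from the valid row width up to
    // the next multiple of the vector length read per iteration.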
    const int input_width                       = input->info()->valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input->info()->data_type());
    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);

    _border_size = BorderSize(0, num_elems_read_per_iteration - input_width, 0, 0);

    INEKernel::configure(win);
}

Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*input, *output));

    return Status{};
}

void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_func == nullptr);

    (*_func)(*_input, *_output, window);
}

namespace
{
Status validate_arguments_logits_softmax(const ITensorInfo &input, const ITensorInfo &max,
                                         const ITensorInfo &output, const float beta, const ITensorInfo &tmp, bool is_log)
{
    ARM_COMPUTE_UNUSED(beta);
    // Check input
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);

    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Check max
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &max);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(input.tensor_shape()).set(0, 1), max.tensor_shape());
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &max);

    // Check output if configured
    if(output.total_size() != 0)
    {
        const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(input.data_type(), is_log) : output.quantization_info();
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON(output.quantization_info() != output_quantization);
    }

    // Check tmp if configured
    if(tmp.total_size() != 0)
    {
        const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
        ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
        // We could potentially reduce tmp memory if we could predict or make an assumption
        // on the maximum number of threads that will run in parallel.
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &tmp);
    }

    return Status{};
}
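// Quantized softmax over a single row. For every element the kernel computes
// max - x with a saturating subtraction on the quantized values, converts the
// result to float and multiplies it by -beta * scale, which yields
// beta * scale * (x - max); the exponentials are accumulated into a running sum
// and cached in the per-thread float scratch buffer so the second pass can
// normalize them (divide by the sum, or subtract log(sum) for log-softmax).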
template <typename T, bool is_log>
void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
{
    static_assert(std::is_same<T, qasymm8_t>::value
                  || std::is_same<T, qasymm8_signed_t>::value,
                  "quantized type should be either qasymm8_t or qasymm8_signed_t.");

    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    const float scale_beta     = -beta * in.info()->quantization_info().uniform().scale;
    const auto  scale_beta_vec = vdupq_n_f32(scale_beta);

    Iterator      in_it(&in, window);
    Iterator      max_it(&max, window);
    Iterator      out_it(&out, window);
    constexpr int vec_size = 16;

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<float *>(tmp);

        float sum{};
        float sum_inversed{};

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
            const auto vec_max = wrapper::vdup_n(max_val, wrapper::traits::vector_128_tag{});

            /* Init sum to zero */
            float32x4x4_t vec_sum =
            {
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
            };

            /* Loop over row and compute exponentials and sum */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                auto vec_elements     = wrapper::vloadq(in_ptr + x);
                vec_elements          = wrapper::vqsub(vec_max, vec_elements);
                auto vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);

                if(is_log)
                {
                    vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
                    vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
                    vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
                    vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
                    vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
                    vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
                    vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
                    vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
                }
                else
                {
                    vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
                    vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
                    vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
                    vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
                    vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
                    vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
                    vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
                    vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
                }

                vst4q_f32(tmp_ptr + x, vec_elements_flt);
            }

            /* Reduce sum */
            const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
            auto       sum_res     = vpadd_f32(vget_high_f32(sum_16_byte), vget_low_f32(sum_16_byte));
            sum_res                = vpadd_f32(sum_res, sum_res);
            sum                    = wrapper::vgetlane(sum_res, 0);

            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                float element{};
                if(is_log)
                {
                    element = (max_val - in_ptr[x]) * scale_beta;
                    sum += std::exp(element);
                }
                else
                {
                    element = std::exp((max_val - in_ptr[x]) * scale_beta);
                    sum += element;
                }

                tmp_ptr[x] = element;
            }

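            // For the standard softmax the output tensor is expected to carry the fixed
            // softmax quantization (scale of 1/256), so scaling by 256 / sum maps the
            // normalized values onto the full 8-bit range; the signed variant additionally
            // shifts by 128 further down.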
            if(!is_log)
            {
                sum_inversed = 256.f / sum;
            }
            else
            {
                sum = std::log(sum);
            }
        }

        /* Normalize exponentials */
        {
            constexpr bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value;
            /* Loop over row and compute softmax */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                using int_vec_type   = wrapper::traits::neon_vector_t<T, 16>;
                float32x4x4_t vec_in = vld4q_f32(tmp_ptr + x);
                int_vec_type normalized_value{};
                if(is_log)
                {
                    const float32x4x4_t sub =
                    {
                        vsubq_f32(vec_in.val[0], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[1], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[2], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[3], vdupq_n_f32(sum)),
                    };
                    normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub);
                }
                else
                {
                    float32x4x4_t mul =
                    {
                        vmulq_f32(vec_in.val[0], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[1], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[2], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[3], vdupq_n_f32(sum_inversed)),
                    };

                    if(is_qasymm8_signed)
                    {
                        const auto offset_vec = wrapper::vdup_n(128.f, wrapper::traits::vector_128_tag{});
                        mul.val[0]            = wrapper::vsub(mul.val[0], offset_vec);
                        mul.val[1]            = wrapper::vsub(mul.val[1], offset_vec);
                        mul.val[2]            = wrapper::vsub(mul.val[2], offset_vec);
                        mul.val[3]            = wrapper::vsub(mul.val[3], offset_vec);
                    }

                    normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(mul);
                }
                wrapper::vstore(out_ptr + x, normalized_value);
            }
            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                if(is_log)
                {
                    out_ptr[x] = utils::cast::saturate_cast<T>(tmp_ptr[x] - sum);
                }
                else
                {
                    out_ptr[x] = utils::cast::saturate_cast<T>((tmp_ptr[x] * sum_inversed) - (is_qasymm8_signed ? 128.f : 0));
                }
            }
        }
    },
    in_it, max_it, out_it);
}

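// Floating-point softmax over a single row: subtract the row maximum, scale by
// beta, exponentiate and store the result into the per-thread scratch buffer
// while accumulating the sum with pairwise adds; the second pass then multiplies
// by 1 / sum (or subtracts log(sum) when is_log is set).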
template <typename T, bool is_log = false>
void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
                             ITensor &out, const float beta, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    Iterator in_it(&in, window);
    Iterator max_it(&max, window);
    Iterator out_it(&out, window);

    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    constexpr int vec_size   = 16 / sizeof(T);
    const int     sum_stages = log2(vec_size / 2);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<T *>(tmp);

        T sum{};
        T sum_inversed{};

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
            const auto vec_max = wrapper::vdup_n(max_val, ExactTagType{});

            /* Init sum to zero */
            auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});

            /* Loop over row and compute exponentials and sum */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                auto vec_elements = wrapper::vloadq(in_ptr + x);
                vec_elements      = wrapper::vsub(vec_elements, vec_max);
                if(is_log)
                {
                    vec_elements = wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}));
                    vec_sum      = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
                }
                else
                {
                    vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{})));
                    vec_sum      = wrapper::vadd(vec_sum, vec_elements);
                }
                wrapper::vstore(tmp_ptr + x, vec_elements);
            }

            /* Reduce sum */
            auto sum_res = wrapper::vpadd(wrapper::vgethigh(vec_sum), wrapper::vgetlow(vec_sum));
            for(int i = 0; i < sum_stages; ++i)
            {
                sum_res = wrapper::vpadd(sum_res, sum_res);
            }
            sum = wrapper::vgetlane(sum_res, 0);

            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                T element{};

                if(is_log)
                {
                    element = (in_ptr[x] - max_val) * beta;
                    sum += std::exp(element);
                }
                else
                {
                    element = std::exp((in_ptr[x] - max_val) * beta);
                    sum += element;
                }
                tmp_ptr[x] = element;
            }

            if(!is_log)
            {
                sum_inversed = T(1) / sum;
            }
            else
            {
                sum = static_cast<T>(std::log(sum));
            }
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                auto vec_in           = wrapper::vloadq(tmp_ptr + x);
                auto normalized_value = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
                if(is_log)
                {
                    normalized_value = wrapper::vsub(vec_in, wrapper::vdup_n(static_cast<T>(sum), ExactTagType{}));
                }
                else
                {
                    normalized_value = wrapper::vmul(vec_in, wrapper::vdup_n(static_cast<T>(sum_inversed), ExactTagType{}));
                }
                wrapper::vstore(out_ptr + x, normalized_value);
            }
            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                if(is_log)
                {
                    out_ptr[x] = tmp_ptr[x] - sum;
                }
                else
                {
                    out_ptr[x] = tmp_ptr[x] * sum_inversed;
                }
            }
        }
    },
    in_it, max_it, out_it);
}
} // namespace

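// These kernels are normally orchestrated by the NESoftmaxLayer function, which
// allocates the max and scratch tensors and schedules one kernel after the other.
// A rough sketch of driving them directly (tensor allocation omitted; the tensor
// names below are placeholders, not part of this file):
//
//   NELogits1DMaxKernel            max_kernel;
//   NELogits1DSoftmaxKernel<false> softmax_kernel;
//   max_kernel.configure(&src, &max);
//   softmax_kernel.configure(&src, &max, &dst, beta, &tmp);
//   NEScheduler::get().schedule(&max_kernel, Window::DimY);
//   NEScheduler::get().schedule(&softmax_kernel, Window::DimY);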
template <bool IS_LOG>
NELogits1DSoftmaxKernel<IS_LOG>::NELogits1DSoftmaxKernel()
    : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _beta(1.0f), _tmp(nullptr)
{
}

template <bool IS_LOG>
void NELogits1DSoftmaxKernel<IS_LOG>::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), max->info(), output->info(), tmp->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*input->info(), *max->info(), *output->info(), beta, *tmp->info(), IS_LOG));

    // Configure kernel window
    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input->info()->data_type());

    // Output auto initialization if not yet initialized
    const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(input->info()->data_type(), IS_LOG) : output->info()->quantization_info();
    auto_init_if_empty(*output->info(), TensorInfo(*input->info()).set_quantization_info(output_quantization).reset_padding());

    // Tmp auto initialization if not yet initialized
    const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input->info()->data_type();
    auto_init_if_empty(*tmp->info(), TensorInfo(*input->info()).set_data_type(tmp_data_type).reset_padding());

    // Configure kernel window
    Window win = calculate_max_window(*max->info(), Steps());
    Coordinates coord;
    coord.set_num_dimensions(output->info()->num_dimensions());
    output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));

    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_softmax_qasymm8<qasymm8_t, IS_LOG>;
            break;
        case DataType::QASYMM8_SIGNED:
            _func = &logits_1d_softmax_qasymm8<qasymm8_signed_t, IS_LOG>;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_softmax_float<float16_t, IS_LOG>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_softmax_float<float, IS_LOG>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
            break;
    }

    _input  = input;
    _max    = max;
    _output = output;
    _beta   = beta;
    _tmp    = tmp;

    INEKernel::configure(win);
}

template <bool IS_LOG>
Status NELogits1DSoftmaxKernel<IS_LOG>::validate(const ITensorInfo *input, const ITensorInfo *max,
                                                 const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*input, *max, *output, beta, *tmp, IS_LOG));

    return Status{};
}

template <bool IS_LOG>
void NELogits1DSoftmaxKernel<IS_LOG>::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

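    // Rows are processed whole, so each thread gets its own row-sized slice of the
    // intermediate tensor to hold the exponentials it computes.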
    const unsigned int num_elems_processed_per_iteration = _input->info()->valid_region().shape.x();
    const unsigned int tmp_size_for_thread                = _tmp->info()->element_size() * num_elems_processed_per_iteration;

    ARM_COMPUTE_ERROR_ON(_tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));

    void *tmp_for_thread = _tmp->buffer() + (info.thread_id * tmp_size_for_thread);

    (*_func)(*_input, *_max, tmp_for_thread, *_output, _beta, window);
}

template class NELogits1DSoftmaxKernel<true>;
template class NELogits1DSoftmaxKernel<false>;

} // namespace arm_compute