blob: a3ecce3a1e0438d669bbab71c71446b6181d0f8f [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01002 * Copyright (c) 2017-2019 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"
25
26#include "arm_compute/core/AccessWindowStatic.h"
Anthony Barbiereaefd002018-07-20 17:49:35 +010027#include "arm_compute/core/CPP/Validate.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028#include "arm_compute/core/Error.h"
29#include "arm_compute/core/Helpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/NEFixedPoint.h"
32#include "arm_compute/core/NEON/NEMath.h"
Manuel Bottini21079dd2019-10-29 17:20:09 +000033#include "arm_compute/core/NEON/wrapper/wrapper.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010034#include "arm_compute/core/TensorInfo.h"
35#include "arm_compute/core/Utils.h"
36#include "arm_compute/core/Validate.h"
37#include "arm_compute/core/Window.h"
Georgios Pinitas303f0db2018-11-19 11:56:51 +000038#include "arm_compute/core/utils/misc/SaturateCast.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010039
40#include <algorithm>
41#include <arm_neon.h>
42#include <cfloat>
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000043#include <functional>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010044
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000045namespace arm_compute
46{
Anthony Barbier6ff3b192017-09-04 18:44:23 +010047namespace
48{
/** Validate the tensor-info pair for the 1D max-reduction kernel.
 *
 * @param[in] input  Input tensor info. Must be QASYMM8, F16 (when FP16 support is compiled in) or F32.
 * @param[in] output Output tensor info; only checked when already configured (total_size() != 0).
 *
 * @return An error Status describing the first failed check, or an empty Status on success.
 */
Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);

    // Validate in case of configured output
    if(output.total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
        // The output shape is the input shape with the X dimension collapsed to 1 (one max per row)
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
    }

    return Status{};
}
64
/** Auto-initialize the output (if empty) and compute the execution window for the 1D max kernel.
 *
 * The kernel reads each row rounded up to a whole number of 16-byte vectors, so the input
 * access window is widened to that multiple and the framework is asked to provide the padding.
 *
 * @param[in,out] input  Input tensor info (padding may be extended).
 * @param[in,out] output Output tensor info; auto-initialized to the input shape with X collapsed to 1.
 *
 * @return A pair of (Status, Window): the Status is an error when the required padding could not
 *         be obtained without changing the window ("Insufficient Padding!").
 */
std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInfo &input, ITensorInfo &output)
{
    // Softmax across the x dimension
    const TensorShape output_shape = TensorShape(input.tensor_shape()).set(0, 1);
    // Output auto initialization if not yet initialized
    auto_init_if_empty(output, output_shape, 1, input.data_type(), input.quantization_info());

    // Configure kernel window
    const int input_width                       = input.valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input.data_type());
    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);

    // Only the first X element of the output is valid (one max per row)
    const ValidRegion out_valid_region(ValidRegion(input.valid_region()).set(0, 0, 1));
    output.set_valid_region(out_valid_region);

    Window win = calculate_max_window(output);

    // Request read access over the row rounded up to a multiple of the vector width
    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), num_elems_read_per_iteration);
    AccessWindowHorizontal output_access(&output, 0, 1);

    const bool window_changed = update_window_and_padding(win, input_access, output_access);

    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
90
/** Compute the maximum of every X-row of @p in and write it to the single-element rows of @p out.
 *
 * @tparam T Element type (qasymm8_t, float16_t or float).
 *
 * @param[in]  in     Input tensor; one max is produced per row of its valid region.
 * @param[out] out    Output tensor (X dimension of size 1).
 * @param[in]  window Execution window (iterates over the collapsed output).
 *
 * NOTE(review): the row loop below steps in whole vectors and may read up to one vector past
 * input_width — this relies on the padding requested in validate_and_configure_window_logits_1d_max;
 * padded lanes can only lower-or-equal values if padding holds copies/garbage >= lowest<T>() — confirm
 * padding contents never exceed the row max.
 */
template <typename T>
void logits_1d_max(const ITensor &in, ITensor &out, const Window &window)
{
    const auto   start_x     = in.info()->valid_region().anchor.x();
    const size_t input_width = in.info()->valid_region().shape.x();

    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    Iterator input(&in, window);
    Iterator output(&out, window);

    constexpr int window_step_x = 16 / sizeof(T);
    // Number of pairwise-max stages needed to reduce a half-vector to a single lane: log2(lanes/2)
    const int sum_stages = log2(window_step_x / 2);
    execute_window_loop(window, [&](const Coordinates &)
    {
        // Get pointers
        const auto in_ptr  = reinterpret_cast<const T *>(input.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(output.ptr());

        // Init max value
        auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});

        // Loop over input row
        for(const T *it = in_ptr; it < (in_ptr + input_width); it += window_step_x)
        {
            const auto current_value = wrapper::vloadq(it);
            vec_max                  = wrapper::vmax(vec_max, current_value);
        }

        // Horizontal reduction: fold high half into low half, then pairwise-max down to one lane
        auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));

        for(int i = 0; i < sum_stages; ++i)
        {
            carry_max = wrapper::vpmax(carry_max, carry_max);
        }
        const T max_val = wrapper::vgetlane(carry_max, 0);
        *out_ptr        = max_val;
    },
    input, output);
}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100132} // namespace
133
/** Default constructor: no kernel function selected yet and an empty border size.
 *  The actual function pointer and border are set in configure(). */
NELogits1DMaxKernel::NELogits1DMaxKernel()
    : _func(nullptr), _border_size()
{
}
138
/** Return the right-hand border (in elements) the kernel may read past the row,
 *  as computed in configure() from the vector-rounded row width. */
BorderSize NELogits1DMaxKernel::border_size() const
{
    return _border_size;
}
143
/** Configure the kernel: validate tensors, select the typed kernel function,
 *  compute the execution window and the right-hand read border.
 *
 * @param[in]  input  Input tensor (QASYMM8 / F16 when supported / F32).
 * @param[out] output Output tensor; auto-initialized by the window helper when empty.
 */
void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), output->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*input->info(), *output->info()));
    // Configure kernel window
    auto win_config = validate_and_configure_window_logits_1d_max(*input->info(), *output->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    // Select the kernel specialization matching the input data type
    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_max<qasymm8_t>;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_max<float16_t>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_max<float>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
    }

    _input  = input;
    _output = output;

    // The kernel reads whole vectors, so it may over-read up to the next multiple of the
    // vector width; expose that over-read as a right-hand border.
    const int input_width                       = input->info()->valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input->info()->data_type());
    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);

    _border_size = BorderSize(0, num_elems_read_per_iteration - input_width, 0, 0);

    INEKernel::configure(win_config.second);
}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100182
/** Static validation entry point: checks arguments and window configuration on cloned
 *  tensor infos so the caller's infos are not modified.
 *
 * @return Empty Status on success, error Status otherwise.
 */
Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);

    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*input, *output));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_max(*input->clone(), *output->clone()).first);

    return Status{};
}
192
/** Execute the selected max-reduction function over the given (sub-)window.
 *  @p info is intentionally unused here — hence ARM_COMPUTE_UNUSED. */
void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_func == nullptr);

    (*_func)(*_input, *_output, window);
}
202
203namespace
204{
/** Validate the tensors of the 1D softmax kernel.
 *
 * @param[in] input  Input tensor info (QASYMM8, F16 when supported, or F32).
 * @param[in] max    Per-row maxima; must match input's type/quantization and have X == 1.
 * @param[in] output Output tensor info; checked only when already configured.
 * @param[in] beta   Scaling factor (not validated — any float accepted).
 * @param[in] tmp    Per-thread scratch tensor info; checked only when already configured.
 *                   For quantized input the scratch must be F32.
 *
 * @return Empty Status on success, error Status on the first failed check.
 */
Status validate_arguments_logits_softmax(const ITensorInfo &input, const ITensorInfo &max,
                                         const ITensorInfo &output, const float beta, const ITensorInfo &tmp)
{
    ARM_COMPUTE_UNUSED(beta);
    // Check input
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);

    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Check max
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &max);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(input.tensor_shape()).set(0, 1), max.tensor_shape());
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &max);

    // Check output if configured
    if(output.total_size() != 0)
    {
        // Quantized softmax always produces values in [0, 1), mapped to the fixed (1/256, 0) scale
        const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON(output.quantization_info() != output_quantization);
    }

    // Check tmp if configured
    if(tmp.total_size() != 0)
    {
        const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
        ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
        // We could potentially reduce tmp memory if we could predict or make an assumption
        // on the maximum number of threads that will run in parallel.
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &tmp);
    }

    return Status{};
}
241
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000242std::pair<Status, Window> validate_and_configure_window_logits_softmax(ITensorInfo &input, ITensorInfo &max,
243 ITensorInfo &output, ITensorInfo &tmp)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100244{
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000245 const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());
Georgios Pinitasd368df32017-07-04 11:06:15 +0100246
247 // Output auto initialization if not yet initialized
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000248 const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
249 auto_init_if_empty(output, TensorInfo(input).set_quantization_info(output_quantization).reset_padding());
Georgios Pinitasd368df32017-07-04 11:06:15 +0100250
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000251 // Tmp auto initialization if not yet initialized
252 const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
253 auto_init_if_empty(tmp, TensorInfo(input).set_data_type(tmp_data_type).reset_padding());
254
255 const int input_width = input.valid_region().shape.x();
256
257 Window win = calculate_max_window(max);
258
259 AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), input_width);
260 AccessWindowHorizontal max_access(&input, 0, 1);
261 AccessWindowHorizontal output_access(&output, input.valid_region().anchor.x(), input_width);
262 AccessWindowHorizontal tmp_access(&tmp, input.valid_region().anchor.x(), input_width);
263
264 const bool window_changed = update_window_and_padding(win, input_access, max_access, output_access, tmp_access);
265
266 output.set_valid_region(input.valid_region());
267
268 const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
269 return std::make_pair(err, win);
270}
271
/** Row-wise (log-)softmax for QASYMM8 input.
 *
 * Each row is processed in two passes: (1) compute exp((x - max) * beta * scale) in F32
 * into the per-thread scratch buffer while accumulating the sum, (2) normalize the scratch
 * values and quantize them back into the output with the fixed 1/256 output scale.
 *
 * @tparam is_log True for log-softmax, false for standard softmax.
 *
 * @param[in]  in     Input tensor (QASYMM8).
 * @param[in]  max    Per-row maxima (QASYMM8, X == 1).
 * @param[in]  tmp    Per-thread F32 scratch buffer, at least input_width elements.
 * @param[out] out    Output tensor (QASYMM8).
 * @param[in]  beta   Softmax scaling factor.
 * @param[in]  window Execution window iterating over rows.
 *
 * NOTE(review): in the is_log path the normalization subtracts the raw exponential sum
 * (`sum`) rather than log(sum); a mathematically correct log-softmax would subtract
 * log(sum), and the fixed (1/256, 0) output quantization cannot represent the negative
 * log-softmax range — verify against the intended quantized log-softmax contract.
 */
template <bool is_log>
void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    // scale_beta is negated because the element difference is computed as (max - x) in u8;
    // -beta * scale * (max - x) == beta * scale * (x - max)
    const float scale_beta     = -beta * in.info()->quantization_info().uniform().scale;
    const auto  scale_beta_vec = vdupq_n_f32(scale_beta);

    Iterator      in_it(&in, window);
    Iterator      max_it(&max, window);
    Iterator      out_it(&out, window);
    constexpr int vec_size = 16;

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const qasymm8_t *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<qasymm8_t *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<float *>(tmp);

        float sum{};
        float sum_inversed{};

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const qasymm8_t *>(max_it.ptr());
            const auto vec_max = vdupq_n_u8(max_val);

            /* Init sum to zero */
            float32x4x4_t vec_sum =
            {
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
            };

            /* Loop over row and compute exponentials and sum */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                auto vec_elements = wrapper::vloadq(in_ptr + x);
                // (max - x) is always >= 0 in u8, then scaled by the negative scale_beta
                vec_elements          = vsubq_u8(vec_max, vec_elements);
                auto vec_elements_flt = convert_uint8x16_to_float32x4x4(vec_elements);

                if(is_log)
                {
                    // Store the scaled logits; accumulate exp() only into the sum
                    vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
                    vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
                    vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
                    vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
                    vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
                    vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
                    vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
                    vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
                }
                else
                {
                    // Store exp() of the scaled logits and accumulate the same values
                    vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
                    vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
                    vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
                    vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
                    vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
                    vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
                    vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
                    vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
                }

                vst4q_f32(tmp_ptr + x, vec_elements_flt);
            }

            /* Reduce sum */
            const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
            auto       sum_res     = vpadd_f32(vget_high_f32(sum_16_byte), vget_low_f32(sum_16_byte));
            sum_res                = vpadd_f32(sum_res, sum_res);
            sum                    = wrapper::vgetlane(sum_res, 0);

            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                float element{};
                if(is_log)
                {
                    element = (max_val - in_ptr[x]) * scale_beta;
                    sum += std::exp(element);
                }
                else
                {
                    element = std::exp((max_val - in_ptr[x]) * scale_beta);
                    sum += element;
                }

                tmp_ptr[x] = element;
            }

            if(!is_log)
            {
                // 256 = output quantization range; maps softmax in [0,1) onto [0,256) before saturation
                sum_inversed = 256.f / sum;
            }
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                float32x4x4_t vec_in = vld4q_f32(tmp_ptr + x);
                uint8x16_t    normalized_value{};
                if(is_log)
                {
                    const float32x4x4_t sub =
                    {
                        vsubq_f32(vec_in.val[0], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[1], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[2], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[3], vdupq_n_f32(sum)),
                    };
                    convert_float32x4x4_to_unit8x16(sub, normalized_value);
                }
                else
                {
                    const float32x4x4_t mul =
                    {
                        vmulq_f32(vec_in.val[0], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[1], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[2], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[3], vdupq_n_f32(sum_inversed)),
                    };
                    convert_float32x4x4_to_unit8x16(mul, normalized_value);
                }
                vst1q_u8(out_ptr + x, normalized_value);
            }
            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                if(is_log)
                {
                    out_ptr[x] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[x] - sum);
                }
                else
                {
                    out_ptr[x] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[x] * sum_inversed);
                }
            }
        }
    },
    in_it, max_it, out_it);
}
423
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100424template <typename T, bool is_log = false>
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000425void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
426 ITensor &out, const float beta, const Window &window)
427{
428 const int start_x = in.info()->valid_region().anchor.x();
429 const int input_width = in.info()->valid_region().shape.x();
430
431 Iterator in_it(&in, window);
432 Iterator max_it(&max, window);
433 Iterator out_it(&out, window);
434
Manuel Bottini21079dd2019-10-29 17:20:09 +0000435 /** NEON vector tag type. */
436 using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
437
438 constexpr int vec_size = 16 / sizeof(T);
439 const int sum_stages = log2(vec_size / 2);
440
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000441 execute_window_loop(window, [&](const Coordinates &)
442 {
443 /* Get pointers */
444 const auto in_ptr = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
445 const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
446 const auto tmp_ptr = reinterpret_cast<T *>(tmp);
447
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100448 T sum{};
449 T sum_inversed{};
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000450
451 /* Compute exponentials and sum */
452 {
453 /* Get max value */
454 const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
Manuel Bottini21079dd2019-10-29 17:20:09 +0000455 const auto vec_max = wrapper::vdup_n(max_val, ExactTagType{});
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000456
457 /* Init sum to zero */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000458 auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000459
460 /* Loop over row and compute exponentials and sum */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000461 int x = 0;
462 for(; x <= (input_width - vec_size); x += vec_size)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000463 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000464 auto vec_elements = wrapper::vloadq(in_ptr + x);
465 vec_elements = wrapper::vsub(vec_elements, vec_max);
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100466 if(is_log)
467 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000468 vec_elements = wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}));
469 vec_sum = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100470 }
471 else
472 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000473 vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{})));
474 vec_sum = wrapper::vadd(vec_sum, vec_elements);
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100475 }
Manuel Bottini21079dd2019-10-29 17:20:09 +0000476 wrapper::vstore(tmp_ptr + x, vec_elements);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000477 }
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100478
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000479 /* Reduce sum */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000480 auto sum_res = wrapper::vpadd(wrapper::vgethigh(vec_sum), wrapper::vgetlow(vec_sum));
481 for(int i = 0; i < sum_stages; ++i)
482 {
483 sum_res = wrapper::vpadd(sum_res, sum_res);
484 }
485 sum = wrapper::vgetlane(sum_res, 0);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000486
487 /* Run remaining elements */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000488 for(; x < input_width; ++x)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000489 {
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100490 T element{};
491
492 if(is_log)
493 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000494 element = (in_ptr[x] - max_val) * beta;
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100495 sum += std::exp(element);
496 }
497 else
498 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000499 element = std::exp((in_ptr[x] - max_val) * beta);
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100500 sum += element;
501 }
Manuel Bottini21079dd2019-10-29 17:20:09 +0000502 tmp_ptr[x] = element;
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000503 }
504
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100505 if(!is_log)
506 {
507 sum_inversed = T(1) / sum;
508 }
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000509 }
510
511 /* Normalize exponentials */
512 {
513 /* Loop over row and compute softmax */
Manuel Bottini21079dd2019-10-29 17:20:09 +0000514 int x = 0;
515 for(; x <= (input_width - vec_size); x += vec_size)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000516 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000517 auto vec_in = wrapper::vloadq(tmp_ptr + x);
518 auto normalized_value = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100519 if(is_log)
520 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000521 normalized_value = wrapper::vsub(vec_in, wrapper::vdup_n(static_cast<T>(sum), ExactTagType{}));
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100522 }
523 else
524 {
Manuel Bottini21079dd2019-10-29 17:20:09 +0000525 normalized_value = wrapper::vmul(vec_in, wrapper::vdup_n(static_cast<T>(sum_inversed), ExactTagType{}));
526 }
527 wrapper::vstore(out_ptr + x, normalized_value);
528 }
529 /* Run remaining elements */
530 for(; x < input_width; ++x)
531 {
532 if(is_log)
533 {
534 out_ptr[x] = tmp_ptr[x] - sum;
535 }
536 else
537 {
538 out_ptr[x] = tmp_ptr[x] * sum_inversed;
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100539 }
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000540 }
541 }
542 },
543 in_it, max_it, out_it);
544}
545} // namespace
546
/** Default constructor: members are reset; the kernel function, tensors and beta
 *  are selected in configure(). */
template <bool IS_LOG>
NELogits1DSoftmaxKernel<IS_LOG>::NELogits1DSoftmaxKernel()
    : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _beta(1.0f), _tmp(nullptr)
{
}
552
/** Configure the (log-)softmax kernel: validate tensors, select the typed kernel
 *  function, and set up the execution window.
 *
 * @param[in]  input  Input tensor (QASYMM8 / F16 when supported / F32).
 * @param[in]  max    Per-row maxima produced by NELogits1DMaxKernel.
 * @param[out] output Output tensor; auto-initialized by the window helper when empty.
 * @param[in]  beta   Softmax scaling factor.
 * @param[in]  tmp    Scratch tensor sized for one row per thread.
 */
template <bool IS_LOG>
void NELogits1DSoftmaxKernel<IS_LOG>::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), max->info(), output->info(), tmp->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*input->info(), *max->info(), *output->info(), beta, *tmp->info()));
    // Configure kernel window
    auto win_config = validate_and_configure_window_logits_softmax(*input->info(), *max->info(), *output->info(), *tmp->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    // Select the kernel specialization matching the input data type
    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_softmax_qasymm8<IS_LOG>;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_softmax_float<float16_t, IS_LOG>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_softmax_float<float, IS_LOG>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
            break;
    }

    _input  = input;
    _max    = max;
    _output = output;
    _beta   = beta;
    _tmp    = tmp;

    INEKernel::configure(win_config.second);
}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100590
/** Static validation entry point: checks arguments and window configuration on cloned
 *  tensor infos so the caller's infos are not modified.
 *
 * @return Empty Status on success, error Status otherwise.
 */
template <bool IS_LOG>
Status NELogits1DSoftmaxKernel<IS_LOG>::validate(const ITensorInfo *input, const ITensorInfo *max,
                                                 const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);

    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*input, *max, *output, beta, *tmp));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_softmax(*input->clone(), *max->clone(), *output->clone(), *tmp->clone()).first);

    return Status{};
}
602
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100603template <bool IS_LOG>
604void NELogits1DSoftmaxKernel<IS_LOG>::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100605{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100606 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100607 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
608 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100609
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000610 const unsigned int num_elems_processed_per_iteration = _input->info()->valid_region().shape.x();
611 const unsigned int tmp_size_for_thread = _tmp->info()->element_size() * num_elems_processed_per_iteration;
612
613 ARM_COMPUTE_ERROR_ON(_tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
614
615 void *tmp_for_thread = _tmp->buffer() + (info.thread_id * tmp_size_for_thread);
616
617 (*_func)(*_input, *_max, tmp_for_thread, *_output, _beta, window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100618}
619
// Explicit instantiations: log-softmax and standard softmax variants
template class NELogits1DSoftmaxKernel<true>;
template class NELogits1DSoftmaxKernel<false>;
622
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000623} // namespace arm_compute