/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/NEMath.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/misc/SaturateCast.h"

#include <algorithm>
#include <arm_neon.h>
#include <cfloat>
#include <functional>

namespace arm_compute
{
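// Helpers to move between the float domain used for the softmax arithmetic and the 8-bit
// integer domain of the quantized tensors. Only the float32x4x4_t <-> uint8x16_t (QASYMM8)
// and float32x4x4_t <-> int8x16_t (QASYMM8_SIGNED) specializations below are provided.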
template <typename float_vec_type, typename int_vec_type>
int_vec_type convert_float_to_int(const float_vec_type &in);

template <typename float_vec_type, typename int_vec_type>
float_vec_type convert_int_to_float(const int_vec_type &in);

template <>
uint8x16_t convert_float_to_int<float32x4x4_t, uint8x16_t>(const float32x4x4_t &in)
{
    uint8x16_t out;
    convert_float32x4x4_to_uint8x16(in, out);
    return out;
}

template <>
int8x16_t convert_float_to_int<float32x4x4_t, int8x16_t>(const float32x4x4_t &in)
{
    int8x16_t out;
    convert_float32x4x4_to_int8x16(in, out);
    return out;
}

template <>
float32x4x4_t convert_int_to_float<float32x4x4_t, uint8x16_t>(const uint8x16_t &in)
{
    return convert_uint8x16_to_float32x4x4(in);
}

template <>
float32x4x4_t convert_int_to_float<float32x4x4_t, int8x16_t>(const int8x16_t &in)
{
    return convert_int8x16_to_float32x4x4(in);
}

namespace
{
Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);

    // Validate in case of configured output
    if(output.total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInfo &input, ITensorInfo &output)
{
    // Softmax across the x dimension
    const TensorShape output_shape = TensorShape(input.tensor_shape()).set(0, 1);
    // Output auto initialization if not yet initialized
    auto_init_if_empty(output, output_shape, 1, input.data_type(), input.quantization_info());

    // Configure kernel window
    const int input_width                       = input.valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input.data_type());
    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);

    const ValidRegion out_valid_region(ValidRegion(input.valid_region()).set(0, 0, 1));
    output.set_valid_region(out_valid_region);

    Window win = calculate_max_window(output);

    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), num_elems_read_per_iteration);
    AccessWindowHorizontal output_access(&output, 0, 1);

    const bool window_changed = update_window_and_padding(win, input_access, output_access);

    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}

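// Computes the maximum of each input row (reduction along the x dimension). A running 128-bit
// vector maximum is kept while striding across the row, then pairwise-reduced to a single
// scalar that is written to the output. The last vector load may read past the valid input
// width; this is covered by the right-hand padding the kernel configures as its border.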
template <typename T>
void logits_1d_max(const ITensor &in, ITensor &out, const Window &window)
{
    const auto   start_x     = in.info()->valid_region().anchor.x();
    const size_t input_width = in.info()->valid_region().shape.x();

    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    Iterator input(&in, window);
    Iterator output(&out, window);

    constexpr int window_step_x = 16 / sizeof(T);
    const int     sum_stages    = log2(window_step_x / 2);
    execute_window_loop(window, [&](const Coordinates &)
    {
        // Get pointers
        const auto in_ptr  = reinterpret_cast<const T *>(input.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(output.ptr());

        // Init max value
        auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});

        // Loop over input row
        for(const T *it = in_ptr; it < (in_ptr + input_width); it += window_step_x)
        {
            const auto current_value = wrapper::vloadq(it);
            vec_max                  = wrapper::vmax(vec_max, current_value);
        }

        auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));

        for(int i = 0; i < sum_stages; ++i)
        {
            carry_max = wrapper::vpmax(carry_max, carry_max);
        }
        const T max_val = wrapper::vgetlane(carry_max, 0);
        *out_ptr        = max_val;
    },
    input, output);
}
} // namespace

NELogits1DMaxKernel::NELogits1DMaxKernel()
    : _func(nullptr), _border_size()
{
}

BorderSize NELogits1DMaxKernel::border_size() const
{
    return _border_size;
}

void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), output->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*input->info(), *output->info()));
    // Configure kernel window
    auto win_config = validate_and_configure_window_logits_1d_max(*input->info(), *output->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_max<qasymm8_t>;
            break;
        case DataType::QASYMM8_SIGNED:
            _func = &logits_1d_max<qasymm8_signed_t>;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_max<float16_t>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_max<float>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
    }

    _input  = input;
    _output = output;

    const int input_width                       = input->info()->valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input->info()->data_type());
    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);

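    // logits_1d_max() always loads full vectors, so it may read up to
    // (num_elems_read_per_iteration - input_width) elements past the end of the row;
    // declare that excess as the right border so the required padding is allocated.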
    _border_size = BorderSize(0, num_elems_read_per_iteration - input_width, 0, 0);

    INEKernel::configure(win_config.second);
}

Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);

    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*input, *output));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_max(*input->clone(), *output->clone()).first);

    return Status{};
}

void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_func == nullptr);

    (*_func)(*_input, *_output, window);
}

namespace
{
Status validate_arguments_logits_softmax(const ITensorInfo &input, const ITensorInfo &max,
                                         const ITensorInfo &output, const float beta, const ITensorInfo &tmp, bool is_log)
{
    ARM_COMPUTE_UNUSED(beta);
    // Check input
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);

    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Check max
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &max);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(input.tensor_shape()).set(0, 1), max.tensor_shape());
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &max);

    // Check output if configured
    if(output.total_size() != 0)
    {
        const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(input.data_type(), is_log) : output.quantization_info();
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON(output.quantization_info() != output_quantization);
    }

    // Check tmp if configured
    if(tmp.total_size() != 0)
    {
        const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
        ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
        // We could potentially reduce tmp memory if we could predict or make an assumption
        // on the maximum number of threads that will run in parallel.
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &tmp);
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window_logits_softmax(ITensorInfo &input, ITensorInfo &max,
                                                                       ITensorInfo &output, ITensorInfo &tmp, bool is_log)
{
    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Output auto initialization if not yet initialized
    const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(input.data_type(), is_log) : output.quantization_info();
    auto_init_if_empty(output, TensorInfo(input).set_quantization_info(output_quantization).reset_padding());

    // Tmp auto initialization if not yet initialized
    const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
    auto_init_if_empty(tmp, TensorInfo(input).set_data_type(tmp_data_type).reset_padding());

    const int input_width = input.valid_region().shape.x();

    Window win = calculate_max_window(max);

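    // The execution window iterates over the max tensor (one point per row); every iteration
    // then accesses a full row of the input, output and tmp tensors.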
    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), input_width);
    AccessWindowHorizontal max_access(&max, 0, 1);
    AccessWindowHorizontal output_access(&output, input.valid_region().anchor.x(), input_width);
    AccessWindowHorizontal tmp_access(&tmp, input.valid_region().anchor.x(), input_width);

    const bool window_changed = update_window_and_padding(win, input_access, max_access, output_access, tmp_access);

    output.set_valid_region(input.valid_region());

    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}

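// (Log-)Softmax along the x dimension for QASYMM8 / QASYMM8_SIGNED inputs. For each row the
// precomputed row maximum is subtracted, the difference is converted to float and scaled by
// -beta * input_scale, and the exponentials (or, for the log variant, the scaled differences)
// are written to the float tmp buffer while their sum is accumulated. The stored values are
// then scaled by 256 / sum (log variant: the accumulated sum is subtracted) and saturate-cast
// to the 8-bit output type; the non-log signed path additionally shifts by -128.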
template <typename T, bool is_log>
void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
{
    static_assert(std::is_same<T, qasymm8_t>::value
                  || std::is_same<T, qasymm8_signed_t>::value,
                  "quantized type should be either qasymm8_t or qasymm8_signed_t.");

    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    const float scale_beta     = -beta * in.info()->quantization_info().uniform().scale;
    const auto  scale_beta_vec = vdupq_n_f32(scale_beta);

    Iterator      in_it(&in, window);
    Iterator      max_it(&max, window);
    Iterator      out_it(&out, window);
    constexpr int vec_size = 16;

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<float *>(tmp);

        float sum{};
        float sum_inversed{};

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
            const auto vec_max = wrapper::vdup_n(max_val, wrapper::traits::vector_128_tag{});

            /* Init sum to zero */
            float32x4x4_t vec_sum =
            {
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
            };

            /* Loop over row and compute exponentials and sum */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                auto vec_elements     = wrapper::vloadq(in_ptr + x);
                vec_elements          = wrapper::vsub(vec_max, vec_elements);
                auto vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);

                if(is_log)
                {
                    vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
                    vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
                    vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
                    vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
                    vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
                    vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
                    vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
                    vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
                }
                else
                {
                    vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
                    vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
                    vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
                    vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
                    vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
                    vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
                    vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
                    vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
                }

                vst4q_f32(tmp_ptr + x, vec_elements_flt);
            }

            /* Reduce sum */
            const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
            auto       sum_res     = vpadd_f32(vget_high_f32(sum_16_byte), vget_low_f32(sum_16_byte));
            sum_res                = vpadd_f32(sum_res, sum_res);
            sum                    = wrapper::vgetlane(sum_res, 0);

            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                float element{};
                if(is_log)
                {
                    element = (max_val - in_ptr[x]) * scale_beta;
                    sum += std::exp(element);
                }
                else
                {
                    element = std::exp((max_val - in_ptr[x]) * scale_beta);
                    sum += element;
                }

                tmp_ptr[x] = element;
            }

            if(!is_log)
            {
                sum_inversed = 256.f / sum;
            }
        }

        /* Normalize exponentials */
        {
            constexpr bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value;
            /* Loop over row and compute softmax */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                using int_vec_type   = wrapper::traits::neon_vector_t<T, 16>;
                float32x4x4_t vec_in = vld4q_f32(tmp_ptr + x);
                int_vec_type  normalized_value{};
                if(is_log)
                {
                    const float32x4x4_t sub =
                    {
                        vsubq_f32(vec_in.val[0], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[1], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[2], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[3], vdupq_n_f32(sum)),
                    };
                    normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub);
                }
                else
                {
                    float32x4x4_t mul =
                    {
                        vmulq_f32(vec_in.val[0], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[1], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[2], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[3], vdupq_n_f32(sum_inversed)),
                    };

                    if(is_qasymm8_signed)
                    {
                        const auto offset_vec = wrapper::vdup_n(128.f, wrapper::traits::vector_128_tag{});
                        mul.val[0]            = wrapper::vsub(mul.val[0], offset_vec);
                        mul.val[1]            = wrapper::vsub(mul.val[1], offset_vec);
                        mul.val[2]            = wrapper::vsub(mul.val[2], offset_vec);
                        mul.val[3]            = wrapper::vsub(mul.val[3], offset_vec);
                    }

                    normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(mul);
                }
                wrapper::vstore(out_ptr + x, normalized_value);
            }
            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                if(is_log)
                {
                    out_ptr[x] = utils::cast::saturate_cast<T>(tmp_ptr[x] - sum);
                }
                else
                {
                    out_ptr[x] = utils::cast::saturate_cast<T>((tmp_ptr[x] * sum_inversed) - (is_qasymm8_signed ? 128.f : 0));
                }
            }
        }
    },
    in_it, max_it, out_it);
}

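// (Log-)Softmax along the x dimension for F16 / F32 inputs. Same scheme as the quantized
// variant, but all arithmetic stays in the input data type: exponentials of beta * (x - max)
// are written to the per-thread tmp buffer while their sum is accumulated, then the stored
// values are multiplied by 1 / sum (log variant: the accumulated sum is subtracted).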
template <typename T, bool is_log = false>
void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
                             ITensor &out, const float beta, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    Iterator in_it(&in, window);
    Iterator max_it(&max, window);
    Iterator out_it(&out, window);

    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    constexpr int vec_size   = 16 / sizeof(T);
    const int     sum_stages = log2(vec_size / 2);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<T *>(tmp);

        T sum{};
        T sum_inversed{};

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
            const auto vec_max = wrapper::vdup_n(max_val, ExactTagType{});

            /* Init sum to zero */
            auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});

            /* Loop over row and compute exponentials and sum */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                auto vec_elements = wrapper::vloadq(in_ptr + x);
                vec_elements      = wrapper::vsub(vec_elements, vec_max);
                if(is_log)
                {
                    vec_elements = wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}));
                    vec_sum      = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
                }
                else
                {
                    vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{})));
                    vec_sum      = wrapper::vadd(vec_sum, vec_elements);
                }
                wrapper::vstore(tmp_ptr + x, vec_elements);
            }

            /* Reduce sum */
            auto sum_res = wrapper::vpadd(wrapper::vgethigh(vec_sum), wrapper::vgetlow(vec_sum));
            for(int i = 0; i < sum_stages; ++i)
            {
                sum_res = wrapper::vpadd(sum_res, sum_res);
            }
            sum = wrapper::vgetlane(sum_res, 0);

            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                T element{};

                if(is_log)
                {
                    element = (in_ptr[x] - max_val) * beta;
                    sum += std::exp(element);
                }
                else
                {
                    element = std::exp((in_ptr[x] - max_val) * beta);
                    sum += element;
                }
                tmp_ptr[x] = element;
            }

            if(!is_log)
            {
                sum_inversed = T(1) / sum;
            }
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                auto vec_in           = wrapper::vloadq(tmp_ptr + x);
                auto normalized_value = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
                if(is_log)
                {
                    normalized_value = wrapper::vsub(vec_in, wrapper::vdup_n(static_cast<T>(sum), ExactTagType{}));
                }
                else
                {
                    normalized_value = wrapper::vmul(vec_in, wrapper::vdup_n(static_cast<T>(sum_inversed), ExactTagType{}));
                }
                wrapper::vstore(out_ptr + x, normalized_value);
            }
            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                if(is_log)
                {
                    out_ptr[x] = tmp_ptr[x] - sum;
                }
                else
                {
                    out_ptr[x] = tmp_ptr[x] * sum_inversed;
                }
            }
        }
    },
    in_it, max_it, out_it);
}
} // namespace

template <bool IS_LOG>
NELogits1DSoftmaxKernel<IS_LOG>::NELogits1DSoftmaxKernel()
    : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _beta(1.0f), _tmp(nullptr)
{
}

template <bool IS_LOG>
void NELogits1DSoftmaxKernel<IS_LOG>::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), max->info(), output->info(), tmp->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*input->info(), *max->info(), *output->info(), beta, *tmp->info(), IS_LOG));
    // Configure kernel window
    auto win_config = validate_and_configure_window_logits_softmax(*input->info(), *max->info(), *output->info(), *tmp->info(), IS_LOG);
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_softmax_qasymm8<qasymm8_t, IS_LOG>;
            break;
        case DataType::QASYMM8_SIGNED:
            _func = &logits_1d_softmax_qasymm8<qasymm8_signed_t, IS_LOG>;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_softmax_float<float16_t, IS_LOG>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_softmax_float<float, IS_LOG>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
            break;
    }

    _input  = input;
    _max    = max;
    _output = output;
    _beta   = beta;
    _tmp    = tmp;

    INEKernel::configure(win_config.second);
}

template <bool IS_LOG>
Status NELogits1DSoftmaxKernel<IS_LOG>::validate(const ITensorInfo *input, const ITensorInfo *max,
                                                 const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);

    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*input, *max, *output, beta, *tmp, IS_LOG));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_softmax(*input->clone(), *max->clone(), *output->clone(), *tmp->clone(), IS_LOG).first);

    return Status{};
}

template <bool IS_LOG>
void NELogits1DSoftmaxKernel<IS_LOG>::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

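    // Each thread works on its own slice of the temporary tensor, sized to hold one full input
    // row in the intermediate data type (F32 when the input is quantized).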
    const unsigned int num_elems_processed_per_iteration = _input->info()->valid_region().shape.x();
    const unsigned int tmp_size_for_thread               = _tmp->info()->element_size() * num_elems_processed_per_iteration;

    ARM_COMPUTE_ERROR_ON(_tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));

    void *tmp_for_thread = _tmp->buffer() + (info.thread_id * tmp_size_for_thread);

    (*_func)(*_input, *_max, tmp_for_thread, *_output, _beta, window);
}

template class NELogits1DSoftmaxKernel<true>;
template class NELogits1DSoftmaxKernel<false>;

} // namespace arm_compute