/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/NEMath.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/misc/Utility.h"

#include <algorithm>
#include <arm_neon.h>
#include <cfloat>
#include <cmath>
#include <functional>
#include <limits>

namespace arm_compute
{
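// Trait that maps an element type T and a lane count N to the corresponding NEON vector type
// (e.g. vec_n_type<float, 4>::type is float32x4_t). The DECLARE_NEON_VEC_TYPE specialisations
// below register every vector type used by this kernel.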
template <typename T, int N>
struct vec_n_type;

#define DECLARE_NEON_VEC_TYPE(T, N, V) \
    template <>                        \
    struct vec_n_type<T, N>            \
    {                                  \
        using type = V;                \
    };

DECLARE_NEON_VEC_TYPE(uint8_t, 16, uint8x16_t)
DECLARE_NEON_VEC_TYPE(uint8_t, 8, uint8x8_t)

DECLARE_NEON_VEC_TYPE(int8_t, 16, int8x16_t)
DECLARE_NEON_VEC_TYPE(int8_t, 8, int8x8_t)

DECLARE_NEON_VEC_TYPE(uint16_t, 8, uint16x8_t)
DECLARE_NEON_VEC_TYPE(uint16_t, 4, uint16x4_t)

DECLARE_NEON_VEC_TYPE(int16_t, 8, int16x8_t)
DECLARE_NEON_VEC_TYPE(int16_t, 4, int16x4_t)

DECLARE_NEON_VEC_TYPE(int32_t, 4, int32x4_t)
DECLARE_NEON_VEC_TYPE(int32_t, 2, int32x2_t)

DECLARE_NEON_VEC_TYPE(uint32_t, 4, uint32x4_t)
DECLARE_NEON_VEC_TYPE(uint32_t, 2, uint32x2_t)

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_VEC_TYPE(float16_t, 8, float16x8_t)
DECLARE_NEON_VEC_TYPE(float16_t, 4, float16x4_t)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

DECLARE_NEON_VEC_TYPE(float, 4, float32x4_t)
DECLARE_NEON_VEC_TYPE(float, 2, float32x2_t)

template <typename T, int N>
using vec_n_t = typename vec_n_type<T, N>::type;

template <typename T, int N>
using vec_n_byte_t = vec_n_t < T, N / sizeof(T) >;

template <typename T>
using vec_16_byte_t = vec_n_byte_t<T, 16>;

template <typename T>
using vec_8_byte_t = vec_n_byte_t<T, 8>;

template <typename T>
using const_ptr_t = const T *;

template <typename T>
using ptr_t = T *;

#define FORWARD_DECLARE_VGET_LANE_FOR_TYPE(TYPE) \
    template <int lane>                          \
    TYPE vget_lane(vec_8_byte_t<TYPE> vec);      \
    template <int lane>                          \
    TYPE vget_lane(vec_16_byte_t<TYPE> vec);

FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint8_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int8_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint16_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int16_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint32_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int32_t)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float16_t)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float)
template <int lane>
float vget_lane(float32x4x4_t vec);

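// elem_type_t<V> recovers the element type of a NEON vector V from the return type of
// vget_lane<0>, and vec_size_of() derives the number of lanes as sizeof(V) / sizeof(element).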
template <typename V>
using elem_type_t = decltype(vget_lane<0>(std::declval<V>()));

template <typename V>
constexpr size_t vec_size_of(const V &vec)
{
    return sizeof(vec) / sizeof(elem_type_t<V>);
}

template <typename V>
V vdup_n(elem_type_t<V> val);
template <typename V>
V vld(const_ptr_t<elem_type_t<V>> ptr);

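// Generates thin, type-generic wrappers (vdup_n, vld, vst, vmax, vpmax, vget_low/high, vget_lane)
// that forward to the NEON intrinsics carrying the per-type suffix given by TAG (u8, s16, f32, ...),
// so that the kernels below can be written once as templates over the element type.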
#define DECLARE_NEON_FUNCTIONS_FOR_TYPE(TYPE, TAG)                                \
    template <>                                                                   \
    inline vec_8_byte_t<TYPE> vdup_n<vec_8_byte_t<TYPE>>(TYPE val)                \
    {                                                                             \
        return vdup_n_##TAG(val);                                                 \
    }                                                                             \
    template <>                                                                   \
    inline vec_16_byte_t<TYPE> vdup_n<vec_16_byte_t<TYPE>>(TYPE val)              \
    {                                                                             \
        return vdupq_n_##TAG(val);                                                \
    }                                                                             \
    template <>                                                                   \
    inline vec_8_byte_t<TYPE> vld<vec_8_byte_t<TYPE>>(const_ptr_t<TYPE> ptr)      \
    {                                                                             \
        return vld1_##TAG(ptr);                                                   \
    }                                                                             \
    template <>                                                                   \
    inline vec_16_byte_t<TYPE> vld<vec_16_byte_t<TYPE>>(const_ptr_t<TYPE> ptr)    \
    {                                                                             \
        return vld1q_##TAG(ptr);                                                  \
    }                                                                             \
    inline void vst(ptr_t<TYPE> ptr, vec_8_byte_t<TYPE> vec)                      \
    {                                                                             \
        vst1_##TAG(ptr, vec);                                                     \
    }                                                                             \
    inline void vst(ptr_t<TYPE> ptr, vec_16_byte_t<TYPE> vec)                     \
    {                                                                             \
        vst1q_##TAG(ptr, vec);                                                    \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vmax(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    {                                                                             \
        return vmaxq_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_8_byte_t<TYPE> vpmax(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b)   \
    {                                                                             \
        return vpmax_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_8_byte_t<TYPE> vget_low(vec_16_byte_t<TYPE> vec)                   \
    {                                                                             \
        return vget_low_##TAG(vec);                                               \
    }                                                                             \
    inline vec_8_byte_t<TYPE> vget_high(vec_16_byte_t<TYPE> vec)                  \
    {                                                                             \
        return vget_high_##TAG(vec);                                              \
    }                                                                             \
    template <int lane>                                                           \
    inline TYPE vget_lane(vec_8_byte_t<TYPE> vec)                                 \
    {                                                                             \
        static_assert(lane >= 0, "lane is out of bounds");                        \
        static_assert(lane < vec_size_of(vec), "lane is out of bounds");          \
        return vget_lane_##TAG(vec, lane);                                        \
    }                                                                             \
    template <int lane>                                                           \
    inline TYPE vget_lane(vec_16_byte_t<TYPE> vec)                                \
    {                                                                             \
        static_assert(lane >= 0, "lane is out of bounds");                        \
        static_assert(lane < vec_size_of(vec), "lane is out of bounds");          \
        return vgetq_lane_##TAG(vec, lane);                                       \
    }

template <typename T>
T sqadd(T a, T b);
template <typename T>
T sqsub(T a, T b);
template <typename T>
T sqmul(T a, T b);

#define DECLARE_NEON_FUNCTIONS_FOR_FLOAT(TYPE, TAG)                               \
    inline vec_8_byte_t<TYPE> vadd(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b)    \
    {                                                                             \
        return vadd_##TAG(a, b);                                                  \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vadd(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    {                                                                             \
        return vaddq_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vsub(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    {                                                                             \
        return vsubq_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vexp(vec_16_byte_t<TYPE> vec)                      \
    {                                                                             \
        return vexpq_##TAG(vec);                                                  \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vmul_n(vec_16_byte_t<TYPE> vec, TYPE val)          \
    {                                                                             \
        return vmulq_n_##TAG(vec, val);                                           \
    }

DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint8_t, u8)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int8_t, s8)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint16_t, u16)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int16_t, s16)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint32_t, u32)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int32_t, s32)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_FUNCTIONS_FOR_TYPE(float16_t, f16)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
DECLARE_NEON_FUNCTIONS_FOR_TYPE(float, f32)

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float16_t, f16)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float, f32)

template <typename VO, typename VI>
VO vcvt(VI vec);

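// Conversions between a full 16-lane QASYMM8 register and four float32x4_t vectors:
// the uint8x16_t -> float32x4x4_t specialisation widens u8 -> u16 -> u32 before converting
// to float, and the reverse specialisation converts back with saturating narrowing.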
template <>
float32x4x4_t vcvt<float32x4x4_t>(uint8x16_t vec)
{
    const auto low  = vmovl_u8(vget_low(vec));
    const auto high = vmovl_u8(vget_high(vec));
    float32x4x4_t res = { {
            vcvtq_f32_u32(vmovl_u16(vget_low(low))),
            vcvtq_f32_u32(vmovl_u16(vget_high(low))),
            vcvtq_f32_u32(vmovl_u16(vget_low(high))),
            vcvtq_f32_u32(vmovl_u16(vget_high(high)))
        }
    };
    return res;
}

template <>
uint8x16_t vcvt<uint8x16_t>(float32x4x4_t vec)
{
    uint16x8x2_t resU16 = { {
            vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[0])),
                         vqmovn_u32(vcvtq_u32_f32(vec.val[1]))),
            vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[2])),
                         vqmovn_u32(vcvtq_u32_f32(vec.val[3])))
        }
    };

    uint8x16_t res = vcombine_u8(vqmovn_u16(resU16.val[0]), vqmovn_u16(resU16.val[1]));
    return res;
}

float32x4x4_t vexp(float32x4x4_t vec)
{
    float32x4x4_t res = { {
            vexpq_f32(vec.val[0]),
            vexpq_f32(vec.val[1]),
            vexpq_f32(vec.val[2]),
            vexpq_f32(vec.val[3])
        }
    };
    return res;
}

template <>
float32x4x4_t vdup_n<float32x4x4_t>(float val)
{
    float32x4x4_t res = { {
            vdupq_n_f32(val),
            vdupq_n_f32(val),
            vdupq_n_f32(val),
            vdupq_n_f32(val)
        }
    };
    return res;
}

float32x4x4_t vmul_n(float32x4x4_t vec, float val)
{
    float32x4x4_t res = { {
            vmulq_n_f32(vec.val[0], val),
            vmulq_n_f32(vec.val[1], val),
            vmulq_n_f32(vec.val[2], val),
            vmulq_n_f32(vec.val[3], val)
        }
    };
    return res;
}

float32x4x4_t vadd(float32x4x4_t a, float32x4x4_t b)
{
    float32x4x4_t res = { {
            vaddq_f32(a.val[0], b.val[0]),
            vaddq_f32(a.val[1], b.val[1]),
            vaddq_f32(a.val[2], b.val[2]),
            vaddq_f32(a.val[3], b.val[3])
        }
    };
    return res;
}

namespace
{
Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
#else  /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::F32);
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

    // Validate in case of configured output
    if(output.total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInfo &input, ITensorInfo &output)
{
    // Softmax across the x dimension
    const TensorShape output_shape = TensorShape(input.tensor_shape()).set(0, 1);
    // Output auto initialization if not yet initialized
    auto_init_if_empty(output, output_shape, 1, input.data_type(), input.quantization_info());

    // Configure kernel window
    const int input_width                       = input.valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input.data_type());
    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);

    const ValidRegion out_valid_region(ValidRegion(input.valid_region()).set(0, 0, 1));
    output.set_valid_region(out_valid_region);

    Window win = calculate_max_window(output);

    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), num_elems_read_per_iteration);
    AccessWindowHorizontal output_access(&output, 0, 1);

    const bool window_changed = update_window_and_padding(win, input_access, output_access);

    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}

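// Horizontal max of a 16-byte vector: one pairwise max folds the high and low halves together,
// then repeated vpmax on the 8-byte result halves the number of candidates until lane 0 holds
// the maximum of the whole register.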
template <typename V>
auto reduce_max(V vec) -> elem_type_t<V>
{
    constexpr int N = vec_size_of(vec);

    auto carry_max = vpmax(vget_high(vec), vget_low(vec));

    for(int k = N / 2; k > 1; k /= 2)
    {
        carry_max = vpmax(carry_max, carry_max);
    }

    return vget_lane<0>(carry_max);
}

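// Computes the maximum value of each input row (the x dimension). The vector loop may read
// past the end of the row, which is why the kernel requests right-hand padding via border_size().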
template <typename T>
void logits_1d_max(const ITensor &in, ITensor &out, const Window &window)
{
    const auto   start_x     = in.info()->valid_region().anchor.x();
    const size_t input_width = in.info()->valid_region().shape.x();

    Iterator input(&in, window);
    Iterator output(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        // Get pointers
        const auto in_ptr  = reinterpret_cast<const T *>(input.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(output.ptr());

        // Init max value
        auto vec_max = vdup_n<vec_16_byte_t<T>>(std::numeric_limits<T>::lowest());

        // Loop over input row
        for(const T *it = in_ptr; it < (in_ptr + input_width); it += vec_size_of(vec_max))
        {
            const auto current_value = vld<vec_16_byte_t<T>>(it);
            vec_max                  = vmax(vec_max, current_value);
        }

        const T max_val = reduce_max(vec_max);
        *out_ptr        = max_val;
    },
    input, output);
}
} // namespace

NELogits1DMaxKernel::NELogits1DMaxKernel()
    : _func(nullptr), _border_size()
{
}

BorderSize NELogits1DMaxKernel::border_size() const
{
    return _border_size;
}

void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), output->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*input->info(), *output->info()));
    // Configure kernel window
    auto win_config = validate_and_configure_window_logits_1d_max(*input->info(), *output->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_max<qasymm8_t>;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_max<float16_t>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_max<float>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
    }

    _input  = input;
    _output = output;

    const int input_width                       = input->info()->valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input->info()->data_type());
    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);

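    // The row is read in whole 16-byte vectors, so the kernel needs the input padded on the
    // right by the difference between the rounded-up read width and the actual row width.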
    _border_size = BorderSize(0, num_elems_read_per_iteration - input_width, 0, 0);

    INEKernel::configure(win_config.second);
}

Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);

    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*input, *output));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_max(*input->clone(), *output->clone()).first);

    return Status{};
}

void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_func == nullptr);

    (*_func)(*_input, *_output, window);
}

namespace
{
Status validate_arguments_logits_softmax(const ITensorInfo &input, const ITensorInfo &max,
                                         const ITensorInfo &output, const float beta, const ITensorInfo &tmp)
{
    ARM_COMPUTE_UNUSED(beta);
    // Check input
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
#else  /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::F32);
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Check max
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &max);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(input.tensor_shape()).set(0, 1), max.tensor_shape());
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &max);

    // Check output if configured
    if(output.total_size() != 0)
    {
        const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON(output.quantization_info() != output_quantization);
    }

    // Check tmp if configured
    if(tmp.total_size() != 0)
    {
        const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
        ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
        // We could potentially reduce tmp memory if we could predict or make an assumption
        // on the maximum number of threads that will run in parallel.
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &tmp);
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window_logits_softmax(ITensorInfo &input, ITensorInfo &max,
                                                                       ITensorInfo &output, ITensorInfo &tmp)
{
    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Output auto initialization if not yet initialized
    const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
    auto_init_if_empty(output, TensorInfo(input).set_quantization_info(output_quantization).reset_padding());

    // Tmp auto initialization if not yet initialized
    const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
    auto_init_if_empty(tmp, TensorInfo(input).set_data_type(tmp_data_type).reset_padding());

    const int input_width = input.valid_region().shape.x();

    Window win = calculate_max_window(max);

    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), input_width);
    AccessWindowHorizontal max_access(&input, 0, 1);
    AccessWindowHorizontal output_access(&output, input.valid_region().anchor.x(), input_width);
    AccessWindowHorizontal tmp_access(&tmp, input.valid_region().anchor.x(), input_width);

    const bool window_changed = update_window_and_padding(win, input_access, max_access, output_access, tmp_access);

    output.set_valid_region(input.valid_region());

    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}

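// Compile-time binary-tree reduction over lanes [S, E] of a NEON vector: each level splits the
// range in two, reduces both halves recursively and combines them with add_fn; the single-lane
// specialisation just extracts that lane. reduce_add() runs the reduction over all lanes.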
template <typename T, int N, int S, int E>
struct reduce_add_impl
{
    template <typename F>
    static T reduce(F add_fn, vec_n_t<T, N> vec)
    {
        constexpr int H            = (S + E + 1) / 2;
        const auto    reduced_high = reduce_add_impl < T, N, S, H - 1 >::reduce(add_fn, vec);
        const auto    reduced_low  = reduce_add_impl<T, N, H, E>::reduce(add_fn, vec);
        return add_fn(reduced_high, reduced_low);
    }
};
template <typename T, int N, int I>
struct reduce_add_impl<T, N, I, I>
{
    template <typename F>
    static T reduce(F /*add_fn*/, vec_n_t<T, N> vec)
    {
        return vget_lane<I>(vec);
    }
};
template <typename V, typename F>
elem_type_t<V> reduce_add(F add_fn, V vec)
{
    constexpr int N = vec_size_of(vec);
    return reduce_add_impl < elem_type_t<V>, N, 0, N - 1 >::reduce(add_fn, vec);
}

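// QASYMM8 softmax for one row: subtract the row maximum in the quantized domain, convert to
// float and apply exp(beta * scale * (x - max)) (scale_beta folds in the sign), accumulate the
// sum while staging the exponentials in tmp, then normalize with 256 / sum so the result maps
// onto the fixed output quantization (scale 1/256, offset 0) with a saturating cast.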
void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    const float scale_beta = -beta * in.info()->quantization_info().scale;

    Iterator in_it(&in, window);
    Iterator max_it(&max, window);
    Iterator out_it(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const qasymm8_t *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<qasymm8_t *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<float *>(tmp);

        float sum_inversed;

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const qasymm8_t *>(max_it.ptr());
            const auto vec_max = vdup_n<vec_16_byte_t<qasymm8_t>>(max_val);

            /* Init sum to zero */
            auto vec_sum = vdup_n<float32x4x4_t>(0.f);

            /* Loop over row and compute exponentials and sum */
            int           i        = 0;
            constexpr int vec_size = vec_size_of(vec_max);
            for(; i <= (input_width - vec_size); i += vec_size)
            {
                auto vec_elements = vld<vec_16_byte_t<qasymm8_t>>(in_ptr + i);
                vec_elements      = vsubq_u8(vec_max, vec_elements);

                auto vec_elements_flt = vcvt<float32x4x4_t>(vec_elements);
                vec_elements_flt      = vexp(vmul_n(vec_elements_flt, scale_beta));

                vec_sum = vadd(vec_sum, vec_elements_flt);

                vst4q_f32(tmp_ptr + i, vec_elements_flt);
            }
            /* Reduce sum */
            const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]),
                                               vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
            const auto sum_8_byte  = vadd_f32(vget_low(sum_16_byte), vget_high(sum_16_byte));
            float      sum         = reduce_add(std::plus<float>(), sum_8_byte);

            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                const float element = std::exp((max_val - in_ptr[i]) * scale_beta);
                sum += element;
                tmp_ptr[i] = element;
            }

            sum_inversed = 256.f / sum;
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int i = 0;
            {
                constexpr int vec_size = 16;
                for(; i <= (input_width - vec_size); i += vec_size)
                {
                    float32x4x4_t vec_in           = vld4q_f32(tmp_ptr + i);
                    auto          normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vmul_n(vec_in, sum_inversed));
                    vst(out_ptr + i, normalized_value);
                }
            }
            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                out_ptr[i] = utility::saturate_cast<qasymm8_t>(tmp_ptr[i] * sum_inversed);
            }
        }
    },
    in_it, max_it, out_it);
}

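// Numerically stable softmax for F32/F16 rows: subtract the row maximum before exponentiation,
// accumulate the sum of exp(beta * (x - max)) while staging the exponentials in tmp, then
// multiply every staged value by 1 / sum. Elements that do not fill a whole vector are handled
// by the scalar tail of each loop.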
template <typename T>
void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
                             ITensor &out, const float beta, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    Iterator in_it(&in, window);
    Iterator max_it(&max, window);
    Iterator out_it(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<T *>(tmp);

        T sum_inversed;

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
            const auto vec_max = vdup_n<vec_16_byte_t<T>>(max_val);

            /* Init sum to zero */
            auto vec_sum = vdup_n<vec_16_byte_t<T>>(0);

            /* Loop over row and compute exponentials and sum */
            int           i        = 0;
            constexpr int vec_size = vec_size_of(vec_sum);
            for(; i <= (input_width - vec_size); i += vec_size)
            {
                auto vec_elements = vld<vec_16_byte_t<T>>(in_ptr + i);
                vec_elements      = vsub(vec_elements, vec_max);
                vec_elements      = vexp(vmul_n(vec_elements, beta));
                vec_sum           = vadd(vec_sum, vec_elements);
                vst(tmp_ptr + i, vec_elements);
            }
            /* Reduce sum */
            const auto sum_8_byte = vadd(vget_high(vec_sum), vget_low(vec_sum));
            T          sum        = reduce_add([](T a, T b) -> T { return a + b; }, sum_8_byte);

            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                T element = std::exp((in_ptr[i] - max_val) * beta);
                sum += element;
                tmp_ptr[i] = element;
            }

            sum_inversed = T(1) / sum;
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int i = 0;
            {
                constexpr int vec_size = vec_size_of(vec_16_byte_t<T> {});
                for(; i <= (input_width - vec_size); i += vec_size)
                {
                    auto             vec_in           = vld<vec_16_byte_t<T>>(tmp_ptr + i);
                    vec_16_byte_t<T> normalized_value = vmul_n(vec_in, sum_inversed);
                    vst(out_ptr + i, normalized_value);
                }
            }
            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                out_ptr[i] = tmp_ptr[i] * sum_inversed;
            }
        }
    },
    in_it, max_it, out_it);
}
} // namespace

NELogits1DSoftmaxKernel::NELogits1DSoftmaxKernel()
    : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _beta(1.0f), _tmp(nullptr)
{
}

void NELogits1DSoftmaxKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), max->info(), output->info(), tmp->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*input->info(), *max->info(), *output->info(), beta, *tmp->info()));
    // Configure kernel window
    auto win_config = validate_and_configure_window_logits_softmax(*input->info(), *max->info(), *output->info(), *tmp->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_softmax_qasymm8;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_softmax_float<float16_t>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_softmax_float<float>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
            break;
    }

    _input  = input;
    _max    = max;
    _output = output;
    _beta   = beta;
    _tmp    = tmp;

    INEKernel::configure(win_config.second);
}

Status NELogits1DSoftmaxKernel::validate(const ITensorInfo *input, const ITensorInfo *max,
                                         const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);

    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*input, *max, *output, beta, *tmp));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_softmax(*input->clone(), *max->clone(), *output->clone(), *tmp->clone()).first);

    return Status{};
}

void NELogits1DSoftmaxKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

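    // Each thread works on a different set of rows but shares the same tmp tensor, so every
    // thread gets its own row-sized slice of tmp, offset by its thread id.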
    const unsigned int num_elems_processed_per_iteration = _input->info()->valid_region().shape.x();
    const unsigned int tmp_size_for_thread               = _tmp->info()->element_size() * num_elems_processed_per_iteration;

    ARM_COMPUTE_ERROR_ON(_tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));

    void *tmp_for_thread = _tmp->buffer() + (info.thread_id * tmp_size_for_thread);

    (*_func)(*_input, *_max, tmp_for_thread, *_output, _beta, window);
}

} // namespace arm_compute