/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/NEMath.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/misc/SaturateCast.h"

#include <algorithm>
#include <arm_neon.h>
#include <cfloat>
#include <functional>

namespace arm_compute
{
template <typename T, int N>
struct vec_n_type;

#define DECLARE_NEON_VEC_TYPE(T, N, V) \
    template <>                        \
    struct vec_n_type<T, N>            \
    {                                  \
        using type = V;                \
    };

DECLARE_NEON_VEC_TYPE(uint8_t, 16, uint8x16_t)
DECLARE_NEON_VEC_TYPE(uint8_t, 8, uint8x8_t)

DECLARE_NEON_VEC_TYPE(int8_t, 16, int8x16_t)
DECLARE_NEON_VEC_TYPE(int8_t, 8, int8x8_t)

DECLARE_NEON_VEC_TYPE(uint16_t, 8, uint16x8_t)
DECLARE_NEON_VEC_TYPE(uint16_t, 4, uint16x4_t)

DECLARE_NEON_VEC_TYPE(int16_t, 8, int16x8_t)
DECLARE_NEON_VEC_TYPE(int16_t, 4, int16x4_t)

DECLARE_NEON_VEC_TYPE(int32_t, 4, int32x4_t)
DECLARE_NEON_VEC_TYPE(int32_t, 2, int32x2_t)

DECLARE_NEON_VEC_TYPE(uint32_t, 4, uint32x4_t)
DECLARE_NEON_VEC_TYPE(uint32_t, 2, uint32x2_t)

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_VEC_TYPE(float16_t, 8, float16x8_t)
DECLARE_NEON_VEC_TYPE(float16_t, 4, float16x4_t)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

DECLARE_NEON_VEC_TYPE(float, 4, float32x4_t)
DECLARE_NEON_VEC_TYPE(float, 2, float32x2_t)

template <typename T, int N>
using vec_n_t = typename vec_n_type<T, N>::type;

template <typename T, int N>
using vec_n_byte_t = vec_n_t<T, N / sizeof(T)>;

template <typename T>
using vec_16_byte_t = vec_n_byte_t<T, 16>;

template <typename T>
using vec_8_byte_t = vec_n_byte_t<T, 8>;

template <typename T>
using const_ptr_t = const T *;

template <typename T>
using ptr_t = T *;

#define FORWARD_DECLARE_VGET_LANE_FOR_TYPE(TYPE) \
    template <int lane>                          \
    TYPE vget_lane(vec_8_byte_t<TYPE> vec);      \
    template <int lane>                          \
    TYPE vget_lane(vec_16_byte_t<TYPE> vec);

FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint8_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int8_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint16_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int16_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint32_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int32_t)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float16_t)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float)
template <int lane>
float vget_lane(float32x4x4_t vec);

template <typename V>
using elem_type_t = decltype(vget_lane<0>(std::declval<V>()));

template <typename V>
constexpr size_t vec_size_of(const V &vec)
{
    return sizeof(vec) / sizeof(elem_type_t<V>);
}
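
// Examples of how the aliases above resolve:
//   vec_16_byte_t<float>  == vec_n_t<float, 4>   == float32x4_t
//   vec_8_byte_t<uint8_t> == vec_n_t<uint8_t, 8> == uint8x8_t
//   vec_size_of(float32x4_t{}) == 4 and elem_type_t<float32x4_t> == float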

template <typename V>
V vdup_n(elem_type_t<V> val);
template <typename V>
V vld(const_ptr_t<elem_type_t<V>> ptr);

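// DECLARE_NEON_FUNCTIONS_FOR_TYPE stamps out the thin wrappers (vdup_n, vld,
// vst, vmax, vpmax, vget_low/high, vget_lane) that map these generic names
// onto the per-type NEON intrinsics, e.g. vld<uint8x16_t> -> vld1q_u8.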
#define DECLARE_NEON_FUNCTIONS_FOR_TYPE(TYPE, TAG) \
    template <> \
    inline vec_8_byte_t<TYPE> vdup_n<vec_8_byte_t<TYPE>>(TYPE val) \
    { \
        return vdup_n_##TAG(val); \
    } \
    template <> \
    inline vec_16_byte_t<TYPE> vdup_n<vec_16_byte_t<TYPE>>(TYPE val) \
    { \
        return vdupq_n_##TAG(val); \
    } \
    template <> \
    inline vec_8_byte_t<TYPE> vld<vec_8_byte_t<TYPE>>(const_ptr_t<TYPE> ptr) \
    { \
        return vld1_##TAG(ptr); \
    } \
    template <> \
    inline vec_16_byte_t<TYPE> vld<vec_16_byte_t<TYPE>>(const_ptr_t<TYPE> ptr) \
    { \
        return vld1q_##TAG(ptr); \
    } \
    inline void vst(ptr_t<TYPE> ptr, vec_8_byte_t<TYPE> vec) \
    { \
        vst1_##TAG(ptr, vec); \
    } \
    inline void vst(ptr_t<TYPE> ptr, vec_16_byte_t<TYPE> vec) \
    { \
        vst1q_##TAG(ptr, vec); \
    } \
    inline vec_16_byte_t<TYPE> vmax(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    { \
        return vmaxq_##TAG(a, b); \
    } \
    inline vec_8_byte_t<TYPE> vpmax(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b) \
    { \
        return vpmax_##TAG(a, b); \
    } \
    inline vec_8_byte_t<TYPE> vget_low(vec_16_byte_t<TYPE> vec) \
    { \
        return vget_low_##TAG(vec); \
    } \
    inline vec_8_byte_t<TYPE> vget_high(vec_16_byte_t<TYPE> vec) \
    { \
        return vget_high_##TAG(vec); \
    } \
    template <int lane> \
    inline TYPE vget_lane(vec_8_byte_t<TYPE> vec) \
    { \
        static_assert(lane >= 0, "lane is out of bounds"); \
        static_assert(lane < vec_size_of(vec), "lane is out of bounds"); \
        return vget_lane_##TAG(vec, lane); \
    } \
    template <int lane> \
    inline TYPE vget_lane(vec_16_byte_t<TYPE> vec) \
    { \
        static_assert(lane >= 0, "lane is out of bounds"); \
        static_assert(lane < vec_size_of(vec), "lane is out of bounds"); \
        return vgetq_lane_##TAG(vec, lane); \
    }

template <typename T>
T sqadd(T a, T b);
template <typename T>
T sqsub(T a, T b);
template <typename T>
T sqmul(T a, T b);

#define DECLARE_NEON_FUNCTIONS_FOR_FLOAT(TYPE, TAG) \
    inline vec_8_byte_t<TYPE> vadd(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b) \
    { \
        return vadd_##TAG(a, b); \
    } \
    inline vec_16_byte_t<TYPE> vadd(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    { \
        return vaddq_##TAG(a, b); \
    } \
    inline vec_16_byte_t<TYPE> vsub(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    { \
        return vsubq_##TAG(a, b); \
    } \
    inline vec_16_byte_t<TYPE> vmul_n(vec_16_byte_t<TYPE> vec, TYPE val) \
    { \
        return vmulq_n_##TAG(vec, val); \
    }

DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint8_t, u8)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int8_t, s8)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint16_t, u16)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int16_t, s16)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint32_t, u32)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int32_t, s32)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_FUNCTIONS_FOR_TYPE(float16_t, f16)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
DECLARE_NEON_FUNCTIONS_FOR_TYPE(float, f32)

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float16_t, f16)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float, f32)

template <typename VO, typename VI>
VO vcvt(VI vec);

template <>
float32x4x4_t vcvt<float32x4x4_t>(uint8x16_t vec)
{
    const auto low  = vmovl_u8(vget_low(vec));
    const auto high = vmovl_u8(vget_high(vec));
    float32x4x4_t res = { {
            vcvtq_f32_u32(vmovl_u16(vget_low(low))),
            vcvtq_f32_u32(vmovl_u16(vget_high(low))),
            vcvtq_f32_u32(vmovl_u16(vget_low(high))),
            vcvtq_f32_u32(vmovl_u16(vget_high(high)))
        }
    };
    return res;
}
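
// vcvt widens 16 x u8 -> 2 x (8 x u16) -> 4 x (4 x u32) -> float32x4x4_t above,
// and the specialization below narrows back with saturating vqmovn at each step.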

template <>
uint8x16_t vcvt<uint8x16_t>(float32x4x4_t vec)
{
    uint16x8x2_t resU16 = { {
            vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[0])),
                         vqmovn_u32(vcvtq_u32_f32(vec.val[1]))),
            vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[2])),
                         vqmovn_u32(vcvtq_u32_f32(vec.val[3])))
        }
    };

    uint8x16_t res = vcombine_u8(vqmovn_u16(resU16.val[0]), vqmovn_u16(resU16.val[1]));
    return res;
}

float32x4x4_t vexp(float32x4x4_t vec)
{
    float32x4x4_t res = { {
            vexpq_f32(vec.val[0]),
            vexpq_f32(vec.val[1]),
            vexpq_f32(vec.val[2]),
            vexpq_f32(vec.val[3])
        }
    };
    return res;
}

float32x4_t vexp(const float32x4_t &vec)
{
    return vexpq_f32(vec);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// TODO (COMPMID-1535) : Revisit FP16 approximations
// Note: exp is widened to F32, evaluated with vexpq_f32 and narrowed back to F16.
float16x8_t vexp(const float16x8_t &vec)
{
    float16x4x2_t res =
    {
        {
            vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vget_low_f16(vec)))),
            vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vget_high_f16(vec))))
        }
    };
    return vcombine_f16(res.val[0], res.val[1]);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
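
// The float32x4x4_t specializations below mirror the 16-byte wrappers so the
// QASYMM8 softmax path can process 16 dequantized floats per iteration using
// the same generic vdup_n / vmul_n / vadd spelling as the float kernels.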

template <>
float32x4x4_t vdup_n<float32x4x4_t>(float val)
{
    float32x4x4_t res = { {
            vdupq_n_f32(val),
            vdupq_n_f32(val),
            vdupq_n_f32(val),
            vdupq_n_f32(val)
        }
    };
    return res;
}

float32x4x4_t vmul_n(float32x4x4_t vec, float val)
{
    float32x4x4_t res = { {
            vmulq_n_f32(vec.val[0], val),
            vmulq_n_f32(vec.val[1], val),
            vmulq_n_f32(vec.val[2], val),
            vmulq_n_f32(vec.val[3], val)
        }
    };
    return res;
}

float32x4x4_t vadd(float32x4x4_t a, float32x4x4_t b)
{
    float32x4x4_t res = { {
            vaddq_f32(a.val[0], b.val[0]),
            vaddq_f32(a.val[1], b.val[1]),
            vaddq_f32(a.val[2], b.val[2]),
            vaddq_f32(a.val[3], b.val[3])
        }
    };
    return res;
}

namespace
{
Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);

    // Validate in case of configured output
    if(output.total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInfo &input, ITensorInfo &output)
{
    // Softmax across the x dimension
    const TensorShape output_shape = TensorShape(input.tensor_shape()).set(0, 1);
    // Output auto initialization if not yet initialized
    auto_init_if_empty(output, output_shape, 1, input.data_type(), input.quantization_info());

    // Configure kernel window
    const int input_width                       = input.valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input.data_type());
    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);

    const ValidRegion out_valid_region(ValidRegion(input.valid_region()).set(0, 0, 1));
    output.set_valid_region(out_valid_region);

    Window win = calculate_max_window(output);

    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), num_elems_read_per_iteration);
    AccessWindowHorizontal output_access(&output, 0, 1);

    const bool window_changed = update_window_and_padding(win, input_access, output_access);

    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
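
// Rows are consumed in whole 16-byte vectors, so the kernel may read up to
// ceil_to_multiple(input_width, vec_size) elements per row; the input access
// window above (and the kernel's border_size) reserve that right-hand padding.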

template <typename V>
auto reduce_max(V vec) -> elem_type_t<V>
{
    constexpr int N = vec_size_of(vec);

    auto carry_max = vpmax(vget_high(vec), vget_low(vec));

    for(int k = N / 2; k > 1; k /= 2)
    {
        carry_max = vpmax(carry_max, carry_max);
    }

    return vget_lane<0>(carry_max);
}
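
// Pairwise-max tree used by reduce_max above: vpmax(high, low) folds N lanes
// to N/2, then each vpmax(carry_max, carry_max) halves the live lanes again
// (e.g. 8 -> 4 -> 2 -> 1 for a uint8x16_t input), leaving the row maximum in
// lane 0 after log2(N) steps.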

template <typename T>
void logits_1d_max(const ITensor &in, ITensor &out, const Window &window)
{
    const auto   start_x     = in.info()->valid_region().anchor.x();
    const size_t input_width = in.info()->valid_region().shape.x();

    Iterator input(&in, window);
    Iterator output(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        // Get pointers
        const auto in_ptr  = reinterpret_cast<const T *>(input.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(output.ptr());

        // Init max value
        auto vec_max = vdup_n<vec_16_byte_t<T>>(support::cpp11::lowest<T>());

        // Loop over input row
        for(const T *it = in_ptr; it < (in_ptr + input_width); it += vec_size_of(vec_max))
        {
            const auto current_value = vld<vec_16_byte_t<T>>(it);
            vec_max                  = vmax(vec_max, current_value);
        }

        const T max_val = reduce_max(vec_max);
        *out_ptr        = max_val;
    },
    input, output);
}
} // namespace

NELogits1DMaxKernel::NELogits1DMaxKernel()
    : _func(nullptr), _border_size()
{
}

BorderSize NELogits1DMaxKernel::border_size() const
{
    return _border_size;
}

void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), output->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*input->info(), *output->info()));
    // Configure kernel window
    auto win_config = validate_and_configure_window_logits_1d_max(*input->info(), *output->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_max<qasymm8_t>;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_max<float16_t>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_max<float>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
    }

    _input  = input;
    _output = output;

    const int input_width                       = input->info()->valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input->info()->data_type());
    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);

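    // The vectorized row loop reads whole 16-byte vectors, so reserve a right
    // border covering the over-read up to the next vector multiple.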
    _border_size = BorderSize(0, num_elems_read_per_iteration - input_width, 0, 0);

    INEKernel::configure(win_config.second);
}

Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);

    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*input, *output));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_max(*input->clone(), *output->clone()).first);

    return Status{};
}

void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_func == nullptr);

    (*_func)(*_input, *_output, window);
}

namespace
{
Status validate_arguments_logits_softmax(const ITensorInfo &input, const ITensorInfo &max,
                                         const ITensorInfo &output, const float beta, const ITensorInfo &tmp)
{
    ARM_COMPUTE_UNUSED(beta);
    // Check input
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);

    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Check max
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &max);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(input.tensor_shape()).set(0, 1), max.tensor_shape());
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &max);

    // Check output if configured
    if(output.total_size() != 0)
    {
        const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON(output.quantization_info() != output_quantization);
    }
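
    // Softmax output lies in [0, 1]; for QASYMM8 that interval maps onto the
    // quantized range with a fixed scale of 1/256 and zero point 0, which is
    // why the output quantization info is forced here rather than user-chosen.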

    // Check tmp if configured
    if(tmp.total_size() != 0)
    {
        const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
        ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
        // We could potentially reduce tmp memory if we could predict or make an assumption
        // on the maximum number of threads that will run in parallel.
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &tmp);
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window_logits_softmax(ITensorInfo &input, ITensorInfo &max,
                                                                       ITensorInfo &output, ITensorInfo &tmp)
{
    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Output auto initialization if not yet initialized
    const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
    auto_init_if_empty(output, TensorInfo(input).set_quantization_info(output_quantization).reset_padding());

    // Tmp auto initialization if not yet initialized
    const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
    auto_init_if_empty(tmp, TensorInfo(input).set_data_type(tmp_data_type).reset_padding());

    const int input_width = input.valid_region().shape.x();

    Window win = calculate_max_window(max);

    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), input_width);
    AccessWindowHorizontal max_access(&max, 0, 1);
    AccessWindowHorizontal output_access(&output, input.valid_region().anchor.x(), input_width);
    AccessWindowHorizontal tmp_access(&tmp, input.valid_region().anchor.x(), input_width);

    const bool window_changed = update_window_and_padding(win, input_access, max_access, output_access, tmp_access);

    output.set_valid_region(input.valid_region());

    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}

template <typename T, int N, int S, int E>
struct reduce_add_impl
{
    template <typename F>
    static T reduce(F add_fn, vec_n_t<T, N> vec)
    {
        constexpr int H            = (S + E + 1) / 2;
        const auto    reduced_high = reduce_add_impl<T, N, S, H - 1>::reduce(add_fn, vec);
        const auto    reduced_low  = reduce_add_impl<T, N, H, E>::reduce(add_fn, vec);
        return add_fn(reduced_high, reduced_low);
    }
};
template <typename T, int N, int I>
struct reduce_add_impl<T, N, I, I>
{
    template <typename F>
    static T reduce(F /*add_fn*/, vec_n_t<T, N> vec)
    {
        return vget_lane<I>(vec);
    }
};
template <typename V, typename F>
elem_type_t<V> reduce_add(F add_fn, V vec)
{
    constexpr int N = vec_size_of(vec);
    return reduce_add_impl<elem_type_t<V>, N, 0, N - 1>::reduce(add_fn, vec);
}
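
// reduce_add above unrolls at compile time into a balanced binary tree over
// lanes [S, E]: each level splits the range in half and combines the halves
// with add_fn, e.g. a 4-lane vector reduces as add(add(l0, l1), add(l2, l3)).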

void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    const float scale_beta = -beta * in.info()->quantization_info().uniform().scale;

    Iterator in_it(&in, window);
    Iterator max_it(&max, window);
    Iterator out_it(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const qasymm8_t *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<qasymm8_t *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<float *>(tmp);

        float sum_inversed;

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const qasymm8_t *>(max_it.ptr());
            const auto vec_max = vdup_n<vec_16_byte_t<qasymm8_t>>(max_val);

            /* Init sum to zero */
            auto vec_sum = vdup_n<float32x4x4_t>(0.f);

            /* Loop over row and compute exponentials and sum */
            int           i        = 0;
            constexpr int vec_size = vec_size_of(vec_max);
            for(; i <= (input_width - vec_size); i += vec_size)
            {
                auto vec_elements = vld<vec_16_byte_t<qasymm8_t>>(in_ptr + i);
                vec_elements      = vsubq_u8(vec_max, vec_elements);

                auto vec_elements_flt = vcvt<float32x4x4_t>(vec_elements);
                vec_elements_flt      = vexp(vmul_n(vec_elements_flt, scale_beta));

                vec_sum = vadd(vec_sum, vec_elements_flt);

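                /* Note: vst4q_f32 stores the four Q registers interleaved; the
                 * matching vld4q_f32 in the normalization pass below
                 * de-interleaves identically, so the round trip through the
                 * tmp buffer is lossless. */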
                vst4q_f32(tmp_ptr + i, vec_elements_flt);
            }
            /* Reduce sum */
            const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]),
                                               vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
            const auto sum_8_byte  = vadd_f32(vget_low(sum_16_byte), vget_high(sum_16_byte));
            float      sum         = reduce_add(std::plus<float>(), sum_8_byte);

            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                const float element = std::exp((max_val - in_ptr[i]) * scale_beta);
                sum += element;
                tmp_ptr[i] = element;
            }

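            /* Using 256.f rather than 1.f folds the inverse of the fixed
             * output scale (1/256, zero point 0) into the normalization, so
             * requantization below costs a single multiply per element. */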
            sum_inversed = 256.f / sum;
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int i = 0;
            {
                constexpr int vec_size = 16;
                for(; i <= (input_width - vec_size); i += vec_size)
                {
                    float32x4x4_t vec_in           = vld4q_f32(tmp_ptr + i);
                    auto          normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vmul_n(vec_in, sum_inversed));
                    vst(out_ptr + i, normalized_value);
                }
            }
            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                out_ptr[i] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[i] * sum_inversed);
            }
        }
    },
    in_it, max_it, out_it);
}

template <typename T>
void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
                             ITensor &out, const float beta, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    Iterator in_it(&in, window);
    Iterator max_it(&max, window);
    Iterator out_it(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<T *>(tmp);

        T sum_inversed;

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
            const auto vec_max = vdup_n<vec_16_byte_t<T>>(max_val);

            /* Init sum to zero */
            auto vec_sum = vdup_n<vec_16_byte_t<T>>(0);

            /* Loop over row and compute exponentials and sum */
            int           i        = 0;
            constexpr int vec_size = vec_size_of(vec_sum);
            for(; i <= (input_width - vec_size); i += vec_size)
            {
                auto vec_elements = vld<vec_16_byte_t<T>>(in_ptr + i);
                vec_elements      = vsub(vec_elements, vec_max);
                vec_elements      = vexp(vmul_n(vec_elements, static_cast<T>(beta)));
                vec_sum           = vadd(vec_sum, vec_elements);
                vst(tmp_ptr + i, vec_elements);
            }
            /* Reduce sum */
            const auto sum_8_byte = vadd(vget_high(vec_sum), vget_low(vec_sum));
            T          sum        = reduce_add([](T a, T b) -> T { return a + b; }, sum_8_byte);

            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                T element = std::exp((in_ptr[i] - max_val) * beta);
                sum += element;
                tmp_ptr[i] = element;
            }

            sum_inversed = T(1) / sum;
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int i = 0;
            {
                constexpr int vec_size = vec_size_of(vec_16_byte_t<T> {});
                for(; i <= (input_width - vec_size); i += vec_size)
                {
                    auto             vec_in           = vld<vec_16_byte_t<T>>(tmp_ptr + i);
                    vec_16_byte_t<T> normalized_value = vmul_n(vec_in, sum_inversed);
                    vst(out_ptr + i, normalized_value);
                }
            }
            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                out_ptr[i] = tmp_ptr[i] * sum_inversed;
            }
        }
    },
    in_it, max_it, out_it);
}
} // namespace

NELogits1DSoftmaxKernel::NELogits1DSoftmaxKernel()
    : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _beta(1.0f), _tmp(nullptr)
{
}

void NELogits1DSoftmaxKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), max->info(), output->info(), tmp->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*input->info(), *max->info(), *output->info(), beta, *tmp->info()));
    // Configure kernel window
    auto win_config = validate_and_configure_window_logits_softmax(*input->info(), *max->info(), *output->info(), *tmp->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_softmax_qasymm8;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_softmax_float<float16_t>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_softmax_float<float>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
            break;
    }

    _input  = input;
    _max    = max;
    _output = output;
    _beta   = beta;
    _tmp    = tmp;

    INEKernel::configure(win_config.second);
}

Status NELogits1DSoftmaxKernel::validate(const ITensorInfo *input, const ITensorInfo *max,
                                         const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);

    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*input, *max, *output, beta, *tmp));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_softmax(*input->clone(), *max->clone(), *output->clone(), *tmp->clone()).first);

    return Status{};
}

void NELogits1DSoftmaxKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    const unsigned int num_elems_processed_per_iteration = _input->info()->valid_region().shape.x();
    const unsigned int tmp_size_for_thread               = _tmp->info()->element_size() * num_elems_processed_per_iteration;

    ARM_COMPUTE_ERROR_ON(_tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));

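    // tmp provides one row of scratch per worker: slice it by thread id so
    // concurrent runs of this kernel never share workspace.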
    void *tmp_for_thread = _tmp->buffer() + (info.thread_id * tmp_size_for_thread);

    (*_func)(*_input, *_max, tmp_for_thread, *_output, _beta, window);
}

} // namespace arm_compute