/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/NEMath.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/misc/Utility.h"

#include <algorithm>
#include <arm_neon.h>
#include <cfloat>
#include <functional>
namespace arm_compute
{
template <typename T, int N>
struct vec_n_type;

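// vec_n_type maps a scalar type and a lane count to the corresponding NEON
// vector type (e.g. vec_n_type<float, 4>::type is float32x4_t); the macro
// below supplies one specialization per supported vector.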
#define DECLARE_NEON_VEC_TYPE(T, N, V) \
    template <>                        \
    struct vec_n_type<T, N>            \
    {                                  \
        using type = V;                \
    };

DECLARE_NEON_VEC_TYPE(uint8_t, 16, uint8x16_t)
DECLARE_NEON_VEC_TYPE(uint8_t, 8, uint8x8_t)

DECLARE_NEON_VEC_TYPE(int8_t, 16, int8x16_t)
DECLARE_NEON_VEC_TYPE(int8_t, 8, int8x8_t)

DECLARE_NEON_VEC_TYPE(uint16_t, 8, uint16x8_t)
DECLARE_NEON_VEC_TYPE(uint16_t, 4, uint16x4_t)

DECLARE_NEON_VEC_TYPE(int16_t, 8, int16x8_t)
DECLARE_NEON_VEC_TYPE(int16_t, 4, int16x4_t)

DECLARE_NEON_VEC_TYPE(int32_t, 4, int32x4_t)
DECLARE_NEON_VEC_TYPE(int32_t, 2, int32x2_t)

DECLARE_NEON_VEC_TYPE(uint32_t, 4, uint32x4_t)
DECLARE_NEON_VEC_TYPE(uint32_t, 2, uint32x2_t)

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_VEC_TYPE(float16_t, 8, float16x8_t)
DECLARE_NEON_VEC_TYPE(float16_t, 4, float16x4_t)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

DECLARE_NEON_VEC_TYPE(float, 4, float32x4_t)
DECLARE_NEON_VEC_TYPE(float, 2, float32x2_t)

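// Convenience aliases: vec_n_byte_t<T, N> picks the vector of T that occupies
// N bytes, so vec_16_byte_t<T> is the full quad register and vec_8_byte_t<T>
// the double register for any element type.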
template <typename T, int N>
using vec_n_t = typename vec_n_type<T, N>::type;

template <typename T, int N>
using vec_n_byte_t = vec_n_t < T, N / sizeof(T) >;

template <typename T>
using vec_16_byte_t = vec_n_byte_t<T, 16>;

template <typename T>
using vec_8_byte_t = vec_n_byte_t<T, 8>;

template <typename T>
using const_ptr_t = const T *;

template <typename T>
using ptr_t = T *;

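// Forward declarations of the vget_lane overloads so that elem_type_t below
// can recover a vector's element type via decltype.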
#define FORWARD_DECLARE_VGET_LANE_FOR_TYPE(TYPE) \
    template <int lane>                          \
    TYPE vget_lane(vec_8_byte_t<TYPE> vec);      \
    template <int lane>                          \
    TYPE vget_lane(vec_16_byte_t<TYPE> vec);

FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint8_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int8_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint16_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int16_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint32_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int32_t)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float16_t)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float)
template <int lane>
float vget_lane(float32x4x4_t vec);

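// elem_type_t<V> is the scalar element type of vector V; vec_size_of returns
// its lane count.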
template <typename V>
using elem_type_t = decltype(vget_lane<0>(std::declval<V>()));

template <typename V>
constexpr size_t vec_size_of(const V &vec)
{
    return sizeof(vec) / sizeof(elem_type_t<V>);
}

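// Generic vdup_n/vld wrappers over the type-suffixed NEON intrinsics,
// specialized per element type by DECLARE_NEON_FUNCTIONS_FOR_TYPE below.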
template <typename V>
V vdup_n(elem_type_t<V> val);
template <typename V>
V vld(const_ptr_t<elem_type_t<V>> ptr);

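// For each (TYPE, TAG) pair this defines the full wrapper set: vdup_n, vld,
// vst, vmax, vpmax, vget_low/vget_high, and bounds-checked vget_lane.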
#define DECLARE_NEON_FUNCTIONS_FOR_TYPE(TYPE, TAG)                                \
    template <>                                                                   \
    inline vec_8_byte_t<TYPE> vdup_n<vec_8_byte_t<TYPE>>(TYPE val)                \
    {                                                                             \
        return vdup_n_##TAG(val);                                                 \
    }                                                                             \
    template <>                                                                   \
    inline vec_16_byte_t<TYPE> vdup_n<vec_16_byte_t<TYPE>>(TYPE val)              \
    {                                                                             \
        return vdupq_n_##TAG(val);                                                \
    }                                                                             \
    template <>                                                                   \
    inline vec_8_byte_t<TYPE> vld<vec_8_byte_t<TYPE>>(const_ptr_t<TYPE> ptr)      \
    {                                                                             \
        return vld1_##TAG(ptr);                                                   \
    }                                                                             \
    template <>                                                                   \
    inline vec_16_byte_t<TYPE> vld<vec_16_byte_t<TYPE>>(const_ptr_t<TYPE> ptr)    \
    {                                                                             \
        return vld1q_##TAG(ptr);                                                  \
    }                                                                             \
    inline void vst(ptr_t<TYPE> ptr, vec_8_byte_t<TYPE> vec)                      \
    {                                                                             \
        vst1_##TAG(ptr, vec);                                                     \
    }                                                                             \
    inline void vst(ptr_t<TYPE> ptr, vec_16_byte_t<TYPE> vec)                     \
    {                                                                             \
        vst1q_##TAG(ptr, vec);                                                    \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vmax(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    {                                                                             \
        return vmaxq_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_8_byte_t<TYPE> vpmax(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b)   \
    {                                                                             \
        return vpmax_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_8_byte_t<TYPE> vget_low(vec_16_byte_t<TYPE> vec)                   \
    {                                                                             \
        return vget_low_##TAG(vec);                                               \
    }                                                                             \
    inline vec_8_byte_t<TYPE> vget_high(vec_16_byte_t<TYPE> vec)                  \
    {                                                                             \
        return vget_high_##TAG(vec);                                              \
    }                                                                             \
    template <int lane>                                                           \
    inline TYPE vget_lane(vec_8_byte_t<TYPE> vec)                                 \
    {                                                                             \
        static_assert(lane >= 0, "lane is out of bounds");                        \
        static_assert(lane < vec_size_of(vec), "lane is out of bounds");          \
        return vget_lane_##TAG(vec, lane);                                        \
    }                                                                             \
    template <int lane>                                                           \
    inline TYPE vget_lane(vec_16_byte_t<TYPE> vec)                                \
    {                                                                             \
        static_assert(lane >= 0, "lane is out of bounds");                        \
        static_assert(lane < vec_size_of(vec), "lane is out of bounds");          \
        return vgetq_lane_##TAG(vec, lane);                                       \
    }

template <typename T>
T sqadd(T a, T b);
template <typename T>
T sqsub(T a, T b);
template <typename T>
T sqmul(T a, T b);

#define DECLARE_NEON_FUNCTIONS_FOR_FLOAT(TYPE, TAG)                               \
    inline vec_8_byte_t<TYPE> vadd(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b)    \
    {                                                                             \
        return vadd_##TAG(a, b);                                                  \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vadd(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    {                                                                             \
        return vaddq_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vsub(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    {                                                                             \
        return vsubq_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vexp(vec_16_byte_t<TYPE> vec)                      \
    {                                                                             \
        return vexpq_##TAG(vec);                                                  \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vmul_n(vec_16_byte_t<TYPE> vec, TYPE val)          \
    {                                                                             \
        return vmulq_n_##TAG(vec, val);                                           \
    }

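// Instantiate the wrappers for every data type the kernels support.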
DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint8_t, u8)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int8_t, s8)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint16_t, u16)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int16_t, s16)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint32_t, u32)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int32_t, s32)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_FUNCTIONS_FOR_TYPE(float16_t, f16)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
DECLARE_NEON_FUNCTIONS_FOR_TYPE(float, f32)

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float16_t, f16)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float, f32)

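// vcvt converts between vector types: the specializations below widen sixteen
// u8 values to four float32x4_t vectors and narrow them back with saturation.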
template <typename VO, typename VI>
VO vcvt(VI vec);

template <>
float32x4x4_t vcvt<float32x4x4_t>(uint8x16_t vec)
{
    const auto low    = vmovl_u8(vget_low(vec));
    const auto high   = vmovl_u8(vget_high(vec));
    float32x4x4_t res = { {
            vcvtq_f32_u32(vmovl_u16(vget_low(low))),
            vcvtq_f32_u32(vmovl_u16(vget_high(low))),
            vcvtq_f32_u32(vmovl_u16(vget_low(high))),
            vcvtq_f32_u32(vmovl_u16(vget_high(high)))
        }
    };
    return res;
}

template <>
uint8x16_t vcvt<uint8x16_t>(float32x4x4_t vec)
{
    uint16x8x2_t resU16 = { {
            vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[0])),
                         vqmovn_u32(vcvtq_u32_f32(vec.val[1]))),
            vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[2])),
                         vqmovn_u32(vcvtq_u32_f32(vec.val[3])))
        }
    };

    uint8x16_t res = vcombine_u8(vqmovn_u16(resU16.val[0]), vqmovn_u16(resU16.val[1]));
    return res;
}

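// float32x4x4_t versions of the arithmetic helpers, used by the QASYMM8 path
// where each iteration processes sixteen values as four float vectors.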
float32x4x4_t vexp(float32x4x4_t vec)
{
    float32x4x4_t res = { {
            vexpq_f32(vec.val[0]),
            vexpq_f32(vec.val[1]),
            vexpq_f32(vec.val[2]),
            vexpq_f32(vec.val[3])
        }
    };
    return res;
}

template <>
float32x4x4_t vdup_n<float32x4x4_t>(float val)
{
    float32x4x4_t res = { {
            vdupq_n_f32(val),
            vdupq_n_f32(val),
            vdupq_n_f32(val),
            vdupq_n_f32(val)
        }
    };
    return res;
}

float32x4x4_t vmul_n(float32x4x4_t vec, float val)
{
    float32x4x4_t res = { {
            vmulq_n_f32(vec.val[0], val),
            vmulq_n_f32(vec.val[1], val),
            vmulq_n_f32(vec.val[2], val),
            vmulq_n_f32(vec.val[3], val)
        }
    };
    return res;
}

float32x4x4_t vadd(float32x4x4_t a, float32x4x4_t b)
{
    float32x4x4_t res = { {
            vaddq_f32(a.val[0], b.val[0]),
            vaddq_f32(a.val[1], b.val[1]),
            vaddq_f32(a.val[2], b.val[2]),
            vaddq_f32(a.val[3], b.val[3])
        }
    };
    return res;
}

namespace
{
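// Validates inputs for the 1D max kernel: supported data type, matching
// type/quantization between input and output, and an output shape equal to
// the input shape with the x dimension collapsed to 1.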
Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);

    // Validate in case of configured output
    if(output.total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
    }

    return Status{};
}

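// Auto-initializes the output, then builds a window that reads each input row
// rounded up to a whole number of vectors and writes one max value per row.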
std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInfo &input, ITensorInfo &output)
{
    // Softmax across the x dimension
    const TensorShape output_shape = TensorShape(input.tensor_shape()).set(0, 1);
    // Output auto initialization if not yet initialized
    auto_init_if_empty(output, output_shape, 1, input.data_type(), input.quantization_info());

    // Configure kernel window
    const int input_width                       = input.valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input.data_type());
    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);

    const ValidRegion out_valid_region(ValidRegion(input.valid_region()).set(0, 0, 1));
    output.set_valid_region(out_valid_region);

    Window win = calculate_max_window(output);

    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), num_elems_read_per_iteration);
    AccessWindowHorizontal output_access(&output, 0, 1);

    const bool window_changed = update_window_and_padding(win, input_access, output_access);

    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}

template <typename V>
auto reduce_max(V vec) -> elem_type_t<V>
{
    constexpr int N = vec_size_of(vec);

    auto carry_max = vpmax(vget_high(vec), vget_low(vec));

    for(int k = N / 2; k > 1; k /= 2)
    {
        carry_max = vpmax(carry_max, carry_max);
    }

    return vget_lane<0>(carry_max);
}

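// Computes the maximum of each input row: a vectorized scan over the row
// followed by a horizontal reduction. Rows are padded to a multiple of the
// vector size, so no scalar tail is needed.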
template <typename T>
void logits_1d_max(const ITensor &in, ITensor &out, const Window &window)
{
    const auto   start_x     = in.info()->valid_region().anchor.x();
    const size_t input_width = in.info()->valid_region().shape.x();

    Iterator input(&in, window);
    Iterator output(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        // Get pointers
        const auto in_ptr  = reinterpret_cast<const T *>(input.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(output.ptr());

        // Init max value
        auto vec_max = vdup_n<vec_16_byte_t<T>>(std::numeric_limits<T>::lowest());

        // Loop over input row
        for(const T *it = in_ptr; it < (in_ptr + input_width); it += vec_size_of(vec_max))
        {
            const auto current_value = vld<vec_16_byte_t<T>>(it);
            vec_max                  = vmax(vec_max, current_value);
        }

        const T max_val = reduce_max(vec_max);
        *out_ptr        = max_val;
    },
    input, output);
}
} // namespace

NELogits1DMaxKernel::NELogits1DMaxKernel()
    : _func(nullptr), _border_size()
{
}

BorderSize NELogits1DMaxKernel::border_size() const
{
    return _border_size;
}

void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), output->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*input->info(), *output->info()));
    // Configure kernel window
    auto win_config = validate_and_configure_window_logits_1d_max(*input->info(), *output->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_max<qasymm8_t>;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_max<float16_t>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_max<float>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
    }

    _input  = input;
    _output = output;

    const int input_width                       = input->info()->valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input->info()->data_type());
    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);

    _border_size = BorderSize(0, num_elems_read_per_iteration - input_width, 0, 0);

    INEKernel::configure(win_config.second);
}

Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);

    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*input, *output));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_max(*input->clone(), *output->clone()).first);

    return Status{};
}

void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_func == nullptr);

    (*_func)(*_input, *_output, window);
}

namespace
{
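// Validates inputs for the softmax kernel. For QASYMM8 the output must use
// the fixed QuantizationInfo(1/256, 0); tmp must match the input shape and
// hold floats when the input is quantized.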
Status validate_arguments_logits_softmax(const ITensorInfo &input, const ITensorInfo &max,
                                         const ITensorInfo &output, const float beta, const ITensorInfo &tmp)
{
    ARM_COMPUTE_UNUSED(beta);
    // Check input
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);

    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Check max
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &max);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(input.tensor_shape()).set(0, 1), max.tensor_shape());
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &max);

    // Check output if configured
    if(output.total_size() != 0)
    {
        const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON(output.quantization_info() != output_quantization);
    }

    // Check tmp if configured
    if(tmp.total_size() != 0)
    {
        const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
        ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
        // We could potentially reduce tmp memory if we could predict or make an assumption
        // on the maximum number of threads that will run in parallel.
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &tmp);
    }

    return Status{};
}

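// Auto-initializes output and tmp, then builds a window over max that grants
// each iteration access to a full row of input, output and tmp.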
std::pair<Status, Window> validate_and_configure_window_logits_softmax(ITensorInfo &input, ITensorInfo &max,
                                                                       ITensorInfo &output, ITensorInfo &tmp)
{
    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Output auto initialization if not yet initialized
    const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
    auto_init_if_empty(output, TensorInfo(input).set_quantization_info(output_quantization).reset_padding());

    // Tmp auto initialization if not yet initialized
    const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
    auto_init_if_empty(tmp, TensorInfo(input).set_data_type(tmp_data_type).reset_padding());

    const int input_width = input.valid_region().shape.x();

    Window win = calculate_max_window(max);

    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), input_width);
    AccessWindowHorizontal max_access(&input, 0, 1);
    AccessWindowHorizontal output_access(&output, input.valid_region().anchor.x(), input_width);
    AccessWindowHorizontal tmp_access(&tmp, input.valid_region().anchor.x(), input_width);

    const bool window_changed = update_window_and_padding(win, input_access, max_access, output_access, tmp_access);

    output.set_valid_region(input.valid_region());

    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}

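// Compile-time binary-tree reduction over the inclusive lane range [S, E];
// reduce_add applies add_fn across all lanes of a vector.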
template <typename T, int N, int S, int E>
struct reduce_add_impl
{
    template <typename F>
    static T reduce(F add_fn, vec_n_t<T, N> vec)
    {
        constexpr int H            = (S + E + 1) / 2;
        const auto    reduced_high = reduce_add_impl < T, N, S, H - 1 >::reduce(add_fn, vec);
        const auto    reduced_low  = reduce_add_impl<T, N, H, E>::reduce(add_fn, vec);
        return add_fn(reduced_high, reduced_low);
    }
};
template <typename T, int N, int I>
struct reduce_add_impl<T, N, I, I>
{
    template <typename F>
    static T reduce(F /*add_fn*/, vec_n_t<T, N> vec)
    {
        return vget_lane<I>(vec);
    }
};
template <typename V, typename F>
elem_type_t<V> reduce_add(F add_fn, V vec)
{
    constexpr int N = vec_size_of(vec);
    return reduce_add_impl < elem_type_t<V>, N, 0, N - 1 >::reduce(add_fn, vec);
}

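// QASYMM8 softmax over one row per window step: subtract the row max,
// exponentiate in float (scale_beta folds the quantization scale, beta and
// the sign flip from (max - x) into a single factor), accumulate the sum,
// then normalize so the results map onto the fixed 1/256 output scale.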
void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    const float scale_beta = -beta * in.info()->quantization_info().scale;

    Iterator in_it(&in, window);
    Iterator max_it(&max, window);
    Iterator out_it(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const qasymm8_t *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<qasymm8_t *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<float *>(tmp);

        float sum_inversed;

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const qasymm8_t *>(max_it.ptr());
            const auto vec_max = vdup_n<vec_16_byte_t<qasymm8_t>>(max_val);

            /* Init sum to zero */
            auto vec_sum = vdup_n<float32x4x4_t>(0.f);

            /* Loop over row and compute exponentials and sum */
            int           i        = 0;
            constexpr int vec_size = vec_size_of(vec_max);
            for(; i <= (input_width - vec_size); i += vec_size)
            {
                auto vec_elements = vld<vec_16_byte_t<qasymm8_t>>(in_ptr + i);
                vec_elements      = vsubq_u8(vec_max, vec_elements);

                auto vec_elements_flt = vcvt<float32x4x4_t>(vec_elements);
                vec_elements_flt      = vexp(vmul_n(vec_elements_flt, scale_beta));

                vec_sum = vadd(vec_sum, vec_elements_flt);

                vst4q_f32(tmp_ptr + i, vec_elements_flt);
            }
            /* Reduce sum */
            const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]),
                                               vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
            const auto sum_8_byte  = vadd_f32(vget_low(sum_16_byte), vget_high(sum_16_byte));
            float      sum         = reduce_add(std::plus<float>(), sum_8_byte);

            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                const float element = std::exp((max_val - in_ptr[i]) * scale_beta);
                sum += element;
                tmp_ptr[i] = element;
            }

            sum_inversed = 256.f / sum;
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int i = 0;
            {
                constexpr int vec_size = 16;
                for(; i <= (input_width - vec_size); i += vec_size)
                {
                    float32x4x4_t vec_in           = vld4q_f32(tmp_ptr + i);
                    auto          normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vmul_n(vec_in, sum_inversed));
                    vst(out_ptr + i, normalized_value);
                }
            }
            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                out_ptr[i] = utility::saturate_cast<qasymm8_t>(tmp_ptr[i] * sum_inversed);
            }
        }
    },
    in_it, max_it, out_it);
}

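// Float (F16/F32) softmax over one row per window step:
// out[i] = exp((in[i] - max) * beta) / sum.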
template <typename T>
void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
                             ITensor &out, const float beta, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    Iterator in_it(&in, window);
    Iterator max_it(&max, window);
    Iterator out_it(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<T *>(tmp);

        T sum_inversed;

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
            const auto vec_max = vdup_n<vec_16_byte_t<T>>(max_val);

            /* Init sum to zero */
            auto vec_sum = vdup_n<vec_16_byte_t<T>>(0);

            /* Loop over row and compute exponentials and sum */
            int           i        = 0;
            constexpr int vec_size = vec_size_of(vec_sum);
            for(; i <= (input_width - vec_size); i += vec_size)
            {
                auto vec_elements = vld<vec_16_byte_t<T>>(in_ptr + i);
                vec_elements      = vsub(vec_elements, vec_max);
                vec_elements      = vexp(vmul_n(vec_elements, beta));
                vec_sum           = vadd(vec_sum, vec_elements);
                vst(tmp_ptr + i, vec_elements);
            }
            /* Reduce sum */
            const auto sum_8_byte = vadd(vget_high(vec_sum), vget_low(vec_sum));
            T          sum        = reduce_add([](T a, T b) -> T { return a + b; }, sum_8_byte);

            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                T element = std::exp((in_ptr[i] - max_val) * beta);
                sum += element;
                tmp_ptr[i] = element;
            }

            sum_inversed = T(1) / sum;
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int i = 0;
            {
                constexpr int vec_size = vec_size_of(vec_16_byte_t<T> {});
                for(; i <= (input_width - vec_size); i += vec_size)
                {
                    auto             vec_in           = vld<vec_16_byte_t<T>>(tmp_ptr + i);
                    vec_16_byte_t<T> normalized_value = vmul_n(vec_in, sum_inversed);
                    vst(out_ptr + i, normalized_value);
                }
            }
            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                out_ptr[i] = tmp_ptr[i] * sum_inversed;
            }
        }
    },
    in_it, max_it, out_it);
}
} // namespace

NELogits1DSoftmaxKernel::NELogits1DSoftmaxKernel()
    : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _beta(1.0f), _tmp(nullptr)
{
}

void NELogits1DSoftmaxKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), max->info(), output->info(), tmp->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*input->info(), *max->info(), *output->info(), beta, *tmp->info()));
    // Configure kernel window
    auto win_config = validate_and_configure_window_logits_softmax(*input->info(), *max->info(), *output->info(), *tmp->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_softmax_qasymm8;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_softmax_float<float16_t>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_softmax_float<float>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
            break;
    }

    _input  = input;
    _max    = max;
    _output = output;
    _beta   = beta;
    _tmp    = tmp;

    INEKernel::configure(win_config.second);
}

Status NELogits1DSoftmaxKernel::validate(const ITensorInfo *input, const ITensorInfo *max,
                                         const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);

    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*input, *max, *output, beta, *tmp));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_softmax(*input->clone(), *max->clone(), *output->clone(), *tmp->clone()).first);

    return Status{};
}

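// Each thread works on its own row-sized slice of the tmp tensor, offset by
// its thread id, so threads never contend on the scratch buffer.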
void NELogits1DSoftmaxKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    const unsigned int num_elems_processed_per_iteration = _input->info()->valid_region().shape.x();
    const unsigned int tmp_size_for_thread               = _tmp->info()->element_size() * num_elems_processed_per_iteration;

    ARM_COMPUTE_ERROR_ON(_tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));

    void *tmp_for_thread = _tmp->buffer() + (info.thread_id * tmp_size_for_thread);

    (*_func)(*_input, *_max, tmp_for_thread, *_output, _beta, window);
}

} // namespace arm_compute