/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/NEMath.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/misc/Utility.h"

#include <algorithm>
#include <arm_neon.h>
#include <cfloat>
#include <cmath>
#include <functional>
#include <limits>
#include <utility>

namespace arm_compute
{
template <typename T, int N>
struct vec_n_type;

#define DECLARE_NEON_VEC_TYPE(T, N, V) \
    template <>                        \
    struct vec_n_type<T, N>            \
    {                                  \
        using type = V;                \
    };

DECLARE_NEON_VEC_TYPE(uint8_t, 16, uint8x16_t)
DECLARE_NEON_VEC_TYPE(uint8_t, 8, uint8x8_t)

DECLARE_NEON_VEC_TYPE(int8_t, 16, int8x16_t)
DECLARE_NEON_VEC_TYPE(int8_t, 8, int8x8_t)

DECLARE_NEON_VEC_TYPE(uint16_t, 8, uint16x8_t)
DECLARE_NEON_VEC_TYPE(uint16_t, 4, uint16x4_t)

DECLARE_NEON_VEC_TYPE(int16_t, 8, int16x8_t)
DECLARE_NEON_VEC_TYPE(int16_t, 4, int16x4_t)

DECLARE_NEON_VEC_TYPE(int32_t, 4, int32x4_t)
DECLARE_NEON_VEC_TYPE(int32_t, 2, int32x2_t)

DECLARE_NEON_VEC_TYPE(uint32_t, 4, uint32x4_t)
DECLARE_NEON_VEC_TYPE(uint32_t, 2, uint32x2_t)

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_VEC_TYPE(float16_t, 8, float16x8_t)
DECLARE_NEON_VEC_TYPE(float16_t, 4, float16x4_t)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

DECLARE_NEON_VEC_TYPE(float, 4, float32x4_t)
DECLARE_NEON_VEC_TYPE(float, 2, float32x2_t)

template <typename T, int N>
using vec_n_t = typename vec_n_type<T, N>::type;

template <typename T, int N>
using vec_n_byte_t = vec_n_t<T, N / sizeof(T)>;

template <typename T>
using vec_16_byte_t = vec_n_byte_t<T, 16>;

template <typename T>
using vec_8_byte_t = vec_n_byte_t<T, 8>;

template <typename T>
using const_ptr_t = const T *;

template <typename T>
using ptr_t = T *;

#define FORWARD_DECLARE_VGET_LANE_FOR_TYPE(TYPE) \
    template <int lane>                          \
    TYPE vget_lane(vec_8_byte_t<TYPE> vec);      \
    template <int lane>                          \
    TYPE vget_lane(vec_16_byte_t<TYPE> vec);

FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint8_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int8_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint16_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int16_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint32_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int32_t)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float16_t)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float)
template <int lane>
float vget_lane(float32x4x4_t vec);

template <typename V>
using elem_type_t = decltype(vget_lane<0>(std::declval<V>()));

template <typename V>
constexpr size_t vec_size_of(const V &vec)
{
    return sizeof(vec) / sizeof(elem_type_t<V>);
}
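// For example, vec_16_byte_t<float> resolves to float32x4_t (16 bytes holding
// 4-byte float lanes), and vec_size_of on such a vector yields 4.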

template <typename V>
V vdup_n(elem_type_t<V> val);
template <typename V>
V vld(const_ptr_t<elem_type_t<V>> ptr);

#define DECLARE_NEON_FUNCTIONS_FOR_TYPE(TYPE, TAG)                                \
    template <>                                                                   \
    inline vec_8_byte_t<TYPE> vdup_n<vec_8_byte_t<TYPE>>(TYPE val)                \
    {                                                                             \
        return vdup_n_##TAG(val);                                                 \
    }                                                                             \
    template <>                                                                   \
    inline vec_16_byte_t<TYPE> vdup_n<vec_16_byte_t<TYPE>>(TYPE val)              \
    {                                                                             \
        return vdupq_n_##TAG(val);                                                \
    }                                                                             \
    template <>                                                                   \
    inline vec_8_byte_t<TYPE> vld<vec_8_byte_t<TYPE>>(const_ptr_t<TYPE> ptr)      \
    {                                                                             \
        return vld1_##TAG(ptr);                                                   \
    }                                                                             \
    template <>                                                                   \
    inline vec_16_byte_t<TYPE> vld<vec_16_byte_t<TYPE>>(const_ptr_t<TYPE> ptr)    \
    {                                                                             \
        return vld1q_##TAG(ptr);                                                  \
    }                                                                             \
    inline void vst(ptr_t<TYPE> ptr, vec_8_byte_t<TYPE> vec)                      \
    {                                                                             \
        vst1_##TAG(ptr, vec);                                                     \
    }                                                                             \
    inline void vst(ptr_t<TYPE> ptr, vec_16_byte_t<TYPE> vec)                     \
    {                                                                             \
        vst1q_##TAG(ptr, vec);                                                    \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vmax(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    {                                                                             \
        return vmaxq_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_8_byte_t<TYPE> vpmax(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b)   \
    {                                                                             \
        return vpmax_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_8_byte_t<TYPE> vget_low(vec_16_byte_t<TYPE> vec)                   \
    {                                                                             \
        return vget_low_##TAG(vec);                                               \
    }                                                                             \
    inline vec_8_byte_t<TYPE> vget_high(vec_16_byte_t<TYPE> vec)                  \
    {                                                                             \
        return vget_high_##TAG(vec);                                              \
    }                                                                             \
    template <int lane>                                                           \
    inline TYPE vget_lane(vec_8_byte_t<TYPE> vec)                                 \
    {                                                                             \
        static_assert(lane >= 0, "lane is out of bounds");                        \
        static_assert(lane < vec_size_of(vec), "lane is out of bounds");          \
        return vget_lane_##TAG(vec, lane);                                        \
    }                                                                             \
    template <int lane>                                                           \
    inline TYPE vget_lane(vec_16_byte_t<TYPE> vec)                                \
    {                                                                             \
        static_assert(lane >= 0, "lane is out of bounds");                        \
        static_assert(lane < vec_size_of(vec), "lane is out of bounds");          \
        return vgetq_lane_##TAG(vec, lane);                                       \
    }
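// Instantiating DECLARE_NEON_FUNCTIONS_FOR_TYPE(float, f32), for instance,
// generates vdup_n/vld/vst/vmax/vpmax/vget_low/vget_high/vget_lane overloads
// that simply forward to the corresponding *_f32 NEON intrinsics.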

template <typename T>
T sqadd(T a, T b);
template <typename T>
T sqsub(T a, T b);
template <typename T>
T sqmul(T a, T b, int fixed_point_position);

#define DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT(TYPET, TYPEU, TAGT, TAGU)                                        \
    inline vec_8_byte_t<TYPET> vqsub(vec_8_byte_t<TYPET> a, vec_8_byte_t<TYPET> b)                              \
    {                                                                                                           \
        return vqsub_##TAGT(a, b);                                                                              \
    }                                                                                                           \
    inline vec_8_byte_t<TYPEU> vqadd(vec_8_byte_t<TYPEU> a, vec_8_byte_t<TYPEU> b)                              \
    {                                                                                                           \
        return vqadd_##TAGU(a, b);                                                                              \
    }                                                                                                           \
    inline vec_16_byte_t<TYPEU> vqadd(vec_16_byte_t<TYPEU> a, vec_16_byte_t<TYPEU> b)                           \
    {                                                                                                           \
        return vqaddq_##TAGU(a, b);                                                                             \
    }                                                                                                           \
    inline vec_8_byte_t<TYPET> vqexp(vec_8_byte_t<TYPET> vec, int fixed_point_position)                         \
    {                                                                                                           \
        return vqexp_q##TAGT(vec, fixed_point_position);                                                        \
    }                                                                                                           \
    inline auto vmovl(vec_8_byte_t<TYPET> vec) -> decltype(vmovl_##TAGT(vec))                                   \
    {                                                                                                           \
        return vmovl_##TAGT(vec);                                                                               \
    }                                                                                                           \
    inline vec_16_byte_t<TYPET> vqrecip(vec_16_byte_t<TYPET> vec, int fixed_point_position)                     \
    {                                                                                                           \
        return vqrecipq_q##TAGT(vec, fixed_point_position);                                                     \
    }                                                                                                           \
    inline vec_16_byte_t<TYPET> vqmul(vec_16_byte_t<TYPET> a, vec_16_byte_t<TYPET> b, int fixed_point_position) \
    {                                                                                                           \
        return vqmulq_q##TAGT(a, b, fixed_point_position);                                                      \
    }                                                                                                           \
    template <>                                                                                                 \
    inline TYPEU sqadd<TYPEU>(TYPEU a, TYPEU b)                                                                 \
    {                                                                                                           \
        return sqadd_q##TAGU(a, b);                                                                             \
    }                                                                                                           \
    inline TYPET sqexp(TYPET val, int fixed_point_position)                                                     \
    {                                                                                                           \
        return sqexp_q##TAGT(val, fixed_point_position);                                                        \
    }                                                                                                           \
    template <>                                                                                                 \
    inline TYPET sqsub<TYPET>(TYPET a, TYPET b)                                                                 \
    {                                                                                                           \
        return sqsub_q##TAGT(a, b);                                                                             \
    }                                                                                                           \
    template <>                                                                                                 \
    inline TYPET sqmul<TYPET>(TYPET a, TYPET b, int fixed_point_position)                                       \
    {                                                                                                           \
        return sqmul_q##TAGT(a, b, fixed_point_position);                                                       \
    }
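// TYPET is the narrow fixed-point element type and TYPEU the widened type it
// accumulates into (e.g. qint8_t sums into qint16_t), so that per-row sums of
// exponentials saturate as late as possible.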

#define DECLARE_NEON_FUNCTIONS_FOR_FLOAT(TYPE, TAG)                               \
    inline vec_8_byte_t<TYPE> vadd(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b)    \
    {                                                                             \
        return vadd_##TAG(a, b);                                                  \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vadd(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    {                                                                             \
        return vaddq_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vsub(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    {                                                                             \
        return vsubq_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vexp(vec_16_byte_t<TYPE> vec)                      \
    {                                                                             \
        return vexpq_##TAG(vec);                                                  \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vmul_n(vec_16_byte_t<TYPE> vec, TYPE val)          \
    {                                                                             \
        return vmulq_n_##TAG(vec, val);                                           \
    }

DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint8_t, u8)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int8_t, s8)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint16_t, u16)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int16_t, s16)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint32_t, u32)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int32_t, s32)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_FUNCTIONS_FOR_TYPE(float16_t, f16)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
DECLARE_NEON_FUNCTIONS_FOR_TYPE(float, f32)

DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT(int8_t, int16_t, s8, s16)
DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT(int16_t, int32_t, s16, s32)

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float16_t, f16)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float, f32)

template <typename VO, typename VI>
VO vcvt(VI vec);

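// Widen 16 x u8 to 16 x f32 in stages (u8 -> u16 -> u32 -> f32), since NEON has
// no direct u8 to f32 conversion.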
template <>
float32x4x4_t vcvt<float32x4x4_t>(uint8x16_t vec)
{
    const auto low  = vmovl_u8(vget_low(vec));
    const auto high = vmovl_u8(vget_high(vec));
    float32x4x4_t res = { {
            vcvtq_f32_u32(vmovl_u16(vget_low(low))),
            vcvtq_f32_u32(vmovl_u16(vget_high(low))),
            vcvtq_f32_u32(vmovl_u16(vget_low(high))),
            vcvtq_f32_u32(vmovl_u16(vget_high(high)))
        }
    };
    return res;
}

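// Narrow back to u8 with saturating vqmovn at each step, so out-of-range values
// clamp instead of wrapping.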
template <>
uint8x16_t vcvt<uint8x16_t>(float32x4x4_t vec)
{
    uint16x8x2_t resU16 = { {
            vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[0])),
                         vqmovn_u32(vcvtq_u32_f32(vec.val[1]))),
            vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[2])),
                         vqmovn_u32(vcvtq_u32_f32(vec.val[3])))
        }
    };

    uint8x16_t res = vcombine_u8(vqmovn_u16(resU16.val[0]), vqmovn_u16(resU16.val[1]));
    return res;
}

float32x4x4_t vexp(float32x4x4_t vec)
{
    float32x4x4_t res = { {
            vexpq_f32(vec.val[0]),
            vexpq_f32(vec.val[1]),
            vexpq_f32(vec.val[2]),
            vexpq_f32(vec.val[3])
        }
    };
    return res;
}

template <>
float32x4x4_t vdup_n<float32x4x4_t>(float val)
{
    float32x4x4_t res = { {
            vdupq_n_f32(val),
            vdupq_n_f32(val),
            vdupq_n_f32(val),
            vdupq_n_f32(val)
        }
    };
    return res;
}

float32x4x4_t vmul_n(float32x4x4_t vec, float val)
{
    float32x4x4_t res = { {
            vmulq_n_f32(vec.val[0], val),
            vmulq_n_f32(vec.val[1], val),
            vmulq_n_f32(vec.val[2], val),
            vmulq_n_f32(vec.val[3], val)
        }
    };
    return res;
}

float32x4x4_t vadd(float32x4x4_t a, float32x4x4_t b)
{
    float32x4x4_t res = { {
            vaddq_f32(a.val[0], b.val[0]),
            vaddq_f32(a.val[1], b.val[1]),
            vaddq_f32(a.val[2], b.val[2]),
            vaddq_f32(a.val[3], b.val[3])
        }
    };
    return res;
}

namespace
{
Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F32);
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

    // Validate in case of configured output
    if(output.total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInfo &input, ITensorInfo &output)
{
    // Softmax across the x dimension
    const TensorShape output_shape = TensorShape(input.tensor_shape()).set(0, 1);
    // Output auto initialization if not yet initialized
    auto_init_if_empty(output, output_shape, 1, input.data_type(), input.fixed_point_position(), input.quantization_info());

    // Configure kernel window
    const int input_width                       = input.valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input.data_type());
    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);

    const ValidRegion out_valid_region(ValidRegion(input.valid_region()).set(0, 0, 1));
    output.set_valid_region(out_valid_region);

    Window win = calculate_max_window(output);

    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), num_elems_read_per_iteration);
    AccessWindowHorizontal output_access(&output, 0, 1);

    const bool window_changed = update_window_and_padding(win, input_access, output_access);

    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}

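// Horizontal max across a vector: one cross-half vpmax folds the 16-byte vector
// into 8 bytes, then repeated pairwise folds leave the overall max in lane 0
// (e.g. for uint8x16_t: 16 -> 8 lanes, then 8 -> 4 -> 2 -> 1).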
template <typename V>
auto reduce_max(V vec) -> elem_type_t<V>
{
    constexpr int N = vec_size_of(vec);

    auto carry_max = vpmax(vget_high(vec), vget_low(vec));

    for(int k = N / 2; k > 1; k /= 2)
    {
        carry_max = vpmax(carry_max, carry_max);
    }

    return vget_lane<0>(carry_max);
}

template <typename T>
void logits_1d_max(const ITensor &in, ITensor &out, const Window &window)
{
    const auto   start_x     = in.info()->valid_region().anchor.x();
    const size_t input_width = in.info()->valid_region().shape.x();

    Iterator input(&in, window);
    Iterator output(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        // Get pointers
        const auto in_ptr  = reinterpret_cast<const T *>(input.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(output.ptr());

        // Init max value
        auto vec_max = vdup_n<vec_16_byte_t<T>>(std::numeric_limits<T>::lowest());

        // Loop over input row
        for(const T *it = in_ptr; it < (in_ptr + input_width); it += vec_size_of(vec_max))
        {
            const auto current_value = vld<vec_16_byte_t<T>>(it);
            vec_max                  = vmax(vec_max, current_value);
        }

        const T max_val = reduce_max(vec_max);
        *out_ptr        = max_val;
    },
    input, output);
}
} // namespace

NELogits1DMaxKernel::NELogits1DMaxKernel()
    : _func(nullptr), _border_size()
{
}

BorderSize NELogits1DMaxKernel::border_size() const
{
    return _border_size;
}

void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), output->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*input->info(), *output->info()));
    // Configure kernel window
    auto win_config = validate_and_configure_window_logits_1d_max(*input->info(), *output->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_max<qasymm8_t>;
            break;
        case DataType::QS8:
            _func = &logits_1d_max<qint8_t>;
            break;
        case DataType::QS16:
            _func = &logits_1d_max<qint16_t>;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_max<float16_t>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_max<float>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
    }

    _input  = input;
    _output = output;

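    // The row is read in whole vectors, so reads round up to a multiple of the
    // vector length; the excess over the valid width becomes the right border.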
    const int input_width                       = input->info()->valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input->info()->data_type());
    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);

    _border_size = BorderSize(0, num_elems_read_per_iteration - input_width, 0, 0);

    INEKernel::configure(win_config.second);
}

Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);

    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*input, *output));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_max(*input->clone(), *output->clone()).first);

    return Status{};
}

void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_func == nullptr);

    (*_func)(*_input, *_output, window);
}

namespace
{
Status validate_arguments_logits_softmax(const ITensorInfo &input, const ITensorInfo &max,
                                         const ITensorInfo &output, const float beta, const ITensorInfo &tmp)
{
    // Check input
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F32);
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Check max
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &max);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(input.tensor_shape()).set(0, 1), max.tensor_shape());
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &max);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &max);

    // Check output if configured
    if(output.total_size() != 0)
    {
        const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON(output.quantization_info() != output_quantization);
    }

    // Check beta
    ARM_COMPUTE_RETURN_ERROR_ON((beta != 1.0f) && is_data_type_fixed_point(input.data_type()));

    // Check tmp if configured
    if(tmp.total_size() != 0)
    {
        const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
        ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &tmp);
        // We could potentially reduce tmp memory if we could predict or make an assumption
        // on the maximum number of threads that will run in parallel.
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &tmp);
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window_logits_softmax(ITensorInfo &input, ITensorInfo &max,
                                                                       ITensorInfo &output, ITensorInfo &tmp)
{
    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Output auto initialization if not yet initialized
    const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
    auto_init_if_empty(output, TensorInfo(input).set_quantization_info(output_quantization).reset_padding());

    // Tmp auto initialization if not yet initialized
    const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
    auto_init_if_empty(tmp, TensorInfo(input).set_data_type(tmp_data_type).reset_padding());

    const int input_width = input.valid_region().shape.x();

    Window win = calculate_max_window(max);

    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), input_width);
    AccessWindowHorizontal max_access(&max, 0, 1);
    AccessWindowHorizontal output_access(&output, input.valid_region().anchor.x(), input_width);
    AccessWindowHorizontal tmp_access(&tmp, input.valid_region().anchor.x(), input_width);

    const bool window_changed = update_window_and_padding(win, input_access, max_access, output_access, tmp_access);

    output.set_valid_region(input.valid_region());

    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}

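// Compile-time divide-and-conquer horizontal add: reduce_add_impl<T, N, S, E>
// halves the lane range [S, E] recursively, and the S == E base case extracts
// a single lane, so reduce_add(f, v) computes e.g. f(f(l0, l1), f(l2, l3)) for
// a 4-lane vector.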
template <typename T, int N, int S, int E>
struct reduce_add_impl
{
    template <typename F>
    static T reduce(F add_fn, vec_n_t<T, N> vec)
    {
        constexpr int H            = (S + E + 1) / 2;
        const auto    reduced_high = reduce_add_impl<T, N, S, H - 1>::reduce(add_fn, vec);
        const auto    reduced_low  = reduce_add_impl<T, N, H, E>::reduce(add_fn, vec);
        return add_fn(reduced_high, reduced_low);
    }
};
template <typename T, int N, int I>
struct reduce_add_impl<T, N, I, I>
{
    template <typename F>
    static T reduce(F /*add_fn*/, vec_n_t<T, N> vec)
    {
        return vget_lane<I>(vec);
    }
};
template <typename V, typename F>
elem_type_t<V> reduce_add(F add_fn, V vec)
{
    constexpr int N = vec_size_of(vec);
    return reduce_add_impl<elem_type_t<V>, N, 0, N - 1>::reduce(add_fn, vec);
}

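// QASYMM8 softmax: exponentials are computed as exp((x - max) * beta * scale)
// in F32 and staged in tmp; the final pass multiplies by 256 / sum so results
// land in the fixed output quantisation of (scale 1/256, offset 0).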
void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    const float scale_beta = -beta * in.info()->quantization_info().scale;

    Iterator in_it(&in, window);
    Iterator max_it(&max, window);
    Iterator out_it(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const qasymm8_t *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<qasymm8_t *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<float *>(tmp);

        float sum_inversed;

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const qasymm8_t *>(max_it.ptr());
            const auto vec_max = vdup_n<vec_16_byte_t<qasymm8_t>>(max_val);

            /* Init sum to zero */
            auto vec_sum = vdup_n<float32x4x4_t>(0.f);

            /* Loop over row and compute exponentials and sum */
            int           i        = 0;
            constexpr int vec_size = vec_size_of(vec_max);
            for(; i <= (input_width - vec_size); i += vec_size)
            {
                auto vec_elements = vld<vec_16_byte_t<qasymm8_t>>(in_ptr + i);
                vec_elements      = vsubq_u8(vec_max, vec_elements);

                auto vec_elements_flt = vcvt<float32x4x4_t>(vec_elements);
                vec_elements_flt      = vexp(vmul_n(vec_elements_flt, scale_beta));

                vec_sum = vadd(vec_sum, vec_elements_flt);

                vst4q_f32(tmp_ptr + i, vec_elements_flt);
            }
            /* Reduce sum */
            const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]),
                                               vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
            const auto sum_8_byte  = vadd_f32(vget_low(sum_16_byte), vget_high(sum_16_byte));
            float      sum         = reduce_add(std::plus<float>(), sum_8_byte);

            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                const float element = std::exp((max_val - in_ptr[i]) * scale_beta);
                sum += element;
                tmp_ptr[i] = element;
            }

            sum_inversed = 256.f / sum;
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int i = 0;
            {
                constexpr int vec_size = 16;
                for(; i <= (input_width - vec_size); i += vec_size)
                {
                    float32x4x4_t vec_in           = vld4q_f32(tmp_ptr + i);
                    auto          normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vmul_n(vec_in, sum_inversed));
                    vst(out_ptr + i, normalized_value);
                }
            }
            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                out_ptr[i] = utility::saturate_cast<qasymm8_t>(tmp_ptr[i] * sum_inversed);
            }
        }
    },
    in_it, max_it, out_it);
}

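// Fixed-point softmax (QS8/QS16): T is the I/O element type and U its widened
// accumulator; saturating fixed-point intrinsics replace the float exp/mul
// path, and normalisation multiplies by the fixed-point reciprocal of the sum.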
template <typename T, typename U>
void logits_1d_softmax_fixed_point(const ITensor &in, const ITensor &max, void *const tmp,
                                   ITensor &out, const float /*beta*/, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    const int fixed_point_position = in.info()->fixed_point_position();

    Iterator in_it(&in, window);
    Iterator max_it(&max, window);
    Iterator out_it(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<T *>(tmp);

        vec_16_byte_t<T> vec_sum_inversed;

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
            const auto vec_max = vdup_n<vec_8_byte_t<T>>(max_val);

            /* Init sum to zero */
            auto vec_sum = vdup_n<vec_16_byte_t<U>>(0);

            /* Loop over row and compute exponentials and sum */
            int           i        = 0;
            constexpr int vec_size = vec_size_of(vec_sum);
            for(; i <= (input_width - vec_size); i += vec_size)
            {
                auto vec_elements = vld<vec_8_byte_t<T>>(in_ptr + i);
                vec_elements      = vqsub(vec_elements, vec_max);
                vec_elements      = vqexp(vec_elements, fixed_point_position);
                vec_sum           = vqadd(vec_sum, vmovl(vec_elements));
                vst(tmp_ptr + i, vec_elements);
            }
            /* Reduce sum */
            const vec_8_byte_t<U> sum_8_byte = vqadd(vget_high(vec_sum), vget_low(vec_sum));
            U                     sum        = reduce_add(sqadd<U>, sum_8_byte);

            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                T element  = sqexp(sqsub(in_ptr[i], max_val), fixed_point_position);
                sum        = sqadd<U>(sum, element);
                tmp_ptr[i] = element;
            }

            const auto qsum  = utility::saturate_cast<T>(sum);
            vec_sum_inversed = vqrecip(vdup_n<vec_16_byte_t<T>>(qsum), fixed_point_position);
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int           i        = 0;
            constexpr int vec_size = vec_size_of(vec_sum_inversed);
            for(; i <= (input_width - vec_size); i += vec_size)
            {
                const auto             vec_in           = vld<vec_16_byte_t<T>>(tmp_ptr + i);
                const vec_16_byte_t<T> normalized_value = vqmul(vec_in, vec_sum_inversed, fixed_point_position);
                vst(out_ptr + i, normalized_value);
            }

            const T sum_inversed = vget_lane<0>(vec_sum_inversed);

            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                out_ptr[i] = sqmul(tmp_ptr[i], sum_inversed, fixed_point_position);
            }
        }
    },
    in_it, max_it, out_it);
}

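// Float softmax (F32, and F16 where available): subtract the row max for
// numerical stability, scale by beta, exponentiate, then multiply by the
// reciprocal of the sum.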
template <typename T>
void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
                             ITensor &out, const float beta, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    Iterator in_it(&in, window);
    Iterator max_it(&max, window);
    Iterator out_it(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<T *>(tmp);

        T sum_inversed;

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
            const auto vec_max = vdup_n<vec_16_byte_t<T>>(max_val);

            /* Init sum to zero */
            auto vec_sum = vdup_n<vec_16_byte_t<T>>(0);

            /* Loop over row and compute exponentials and sum */
            int           i        = 0;
            constexpr int vec_size = vec_size_of(vec_sum);
            for(; i <= (input_width - vec_size); i += vec_size)
            {
                auto vec_elements = vld<vec_16_byte_t<T>>(in_ptr + i);
                vec_elements      = vsub(vec_elements, vec_max);
                vec_elements      = vexp(vmul_n(vec_elements, beta));
                vec_sum           = vadd(vec_sum, vec_elements);
                vst(tmp_ptr + i, vec_elements);
            }
            /* Reduce sum */
            const auto sum_8_byte = vadd(vget_high(vec_sum), vget_low(vec_sum));
            T          sum        = reduce_add([](T a, T b) -> T { return a + b; }, sum_8_byte);

            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                T element = std::exp((in_ptr[i] - max_val) * beta);
                sum += element;
                tmp_ptr[i] = element;
            }

            sum_inversed = T(1) / sum;
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int i = 0;
            {
                constexpr int vec_size = vec_size_of(vec_16_byte_t<T> {});
                for(; i <= (input_width - vec_size); i += vec_size)
                {
                    auto             vec_in           = vld<vec_16_byte_t<T>>(tmp_ptr + i);
                    vec_16_byte_t<T> normalized_value = vmul_n(vec_in, sum_inversed);
                    vst(out_ptr + i, normalized_value);
                }
            }
            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                out_ptr[i] = tmp_ptr[i] * sum_inversed;
            }
        }
    },
    in_it, max_it, out_it);
}
} // namespace

NELogits1DSoftmaxKernel::NELogits1DSoftmaxKernel()
    : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _beta(1.0f), _tmp(nullptr)
{
}

void NELogits1DSoftmaxKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), max->info(), output->info(), tmp->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*input->info(), *max->info(), *output->info(), beta, *tmp->info()));
    // Configure kernel window
    auto win_config = validate_and_configure_window_logits_softmax(*input->info(), *max->info(), *output->info(), *tmp->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_softmax_qasymm8;
            break;
        case DataType::QS8:
            _func = &logits_1d_softmax_fixed_point<qint8_t, qint16_t>;
            break;
        case DataType::QS16:
            _func = &logits_1d_softmax_fixed_point<qint16_t, qint32_t>;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_softmax_float<float16_t>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_softmax_float<float>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
            break;
    }

    _input  = input;
    _max    = max;
    _output = output;
    _beta   = beta;
    _tmp    = tmp;

    INEKernel::configure(win_config.second);
}

Status NELogits1DSoftmaxKernel::validate(const ITensorInfo *input, const ITensorInfo *max,
                                         const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);

    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*input, *max, *output, beta, *tmp));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_softmax(*input->clone(), *max->clone(), *output->clone(), *tmp->clone()).first);

    return Status{};
}

void NELogits1DSoftmaxKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

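    // Each thread works on whole rows and gets its own row-sized slice of the
    // shared tmp tensor, indexed by thread_id, so no synchronisation is needed.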
    const unsigned int num_elems_processed_per_iteration = _input->info()->valid_region().shape.x();
    const unsigned int tmp_size_for_thread               = _tmp->info()->element_size() * num_elems_processed_per_iteration;

    ARM_COMPUTE_ERROR_ON(_tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));

    void *tmp_for_thread = _tmp->buffer() + (info.thread_id * tmp_size_for_thread);

    (*_func)(*_input, *_max, tmp_for_thread, *_output, _beta, window);
}

} // namespace arm_compute