blob: f1027590e43711eff044bbea783722a10e1ab84b [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"
25
26#include "arm_compute/core/AccessWindowStatic.h"
27#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/ITensor.h"
30#include "arm_compute/core/NEON/NEFixedPoint.h"
31#include "arm_compute/core/NEON/NEMath.h"
32#include "arm_compute/core/TensorInfo.h"
33#include "arm_compute/core/Utils.h"
34#include "arm_compute/core/Validate.h"
35#include "arm_compute/core/Window.h"
36
37#include <algorithm>
38#include <arm_neon.h>
39#include <cfloat>
40
41using namespace arm_compute;
42
43namespace
44{
Georgios Pinitas9247c922017-06-28 18:29:47 +010045void logits_1d_max_qs8(const ITensor *in, ITensor *out, const Window &window)
46{
47 Window in_slice = window.first_slice_window_1D();
48
49 Window window_max(window);
50 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
51 Window max_slice = window_max.first_slice_window_1D();
52
53 do
54 {
55 Iterator input(in, in_slice);
56 Iterator output(out, max_slice);
57
58 qint8x16_t vec_max = vdupq_n_s8(std::numeric_limits<qint8_t>::lowest());
59
60 execute_window_loop(in_slice, [&](const Coordinates & id)
61 {
62 const auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
63 const qint8x16_t current_value = vld1q_qs8(in_ptr);
64 vec_max = vmaxq_qs8(vec_max, current_value);
65 },
66 input);
67
68 qint8x8_t carry_max = vpmax_qs8(vget_high_s8(vec_max), vget_low_s8(vec_max));
69 carry_max = vpmax_qs8(carry_max, carry_max);
70 carry_max = vpmax_qs8(carry_max, carry_max);
71 carry_max = vpmax_qs8(carry_max, carry_max);
72
73 *(reinterpret_cast<qint8_t *>(output.ptr())) = vget_lane_s8(carry_max, 0);
74 }
75 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
76}
77void logits_1d_max_qs16(const ITensor *in, ITensor *out, const Window &window)
78{
79 Window in_slice = window.first_slice_window_1D();
80
81 Window window_max(window);
82 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
83 Window max_slice = window_max.first_slice_window_1D();
84
85 do
86 {
87 Iterator input(in, in_slice);
88 Iterator output(out, max_slice);
89
90 qint16x8_t vec_max = vdupq_n_qs16(std::numeric_limits<qint16_t>::lowest());
91
92 execute_window_loop(in_slice, [&](const Coordinates & id)
93 {
94 const auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
95 const qint16x8_t current_value = vld1q_qs16(in_ptr);
96 vec_max = vmaxq_qs16(vec_max, current_value);
97 },
98 input);
99
100 qint16x4_t carry_max = vpmax_qs16(vget_high_qs16(vec_max), vget_low_qs16(vec_max));
101 carry_max = vpmax_qs16(carry_max, carry_max);
102 carry_max = vpmax_qs16(carry_max, carry_max);
103
104 *(reinterpret_cast<qint16_t *>(output.ptr())) = vget_lane_s16(carry_max, 0);
105 }
106 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
107}
Pablo Tellob49a7152017-07-11 16:31:35 +0100108
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000109#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tellob49a7152017-07-11 16:31:35 +0100110void logits_1d_max_f16(const ITensor *in, ITensor *out, const Window &window)
111{
112 Window in_slice = window.first_slice_window_1D();
113
114 Window window_max(window);
115 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
116 Window max_slice = window_max.first_slice_window_1D();
117
118 do
119 {
120 Iterator input(in, in_slice);
121 Iterator output(out, max_slice);
122
123 float16x8_t vec_max = vdupq_n_f16(std::numeric_limits<float16_t>::lowest());
124
125 execute_window_loop(in_slice, [&](const Coordinates & id)
126 {
127 const auto in_ptr = reinterpret_cast<const float16_t *>(input.ptr());
128 const float16x8_t current_value = vld1q_f16(in_ptr);
129 vec_max = vmaxq_f16(vec_max, current_value);
130 },
131 input);
132
133 float16x4_t carry_max = vpmax_f16(vget_high_f16(vec_max), vget_low_f16(vec_max));
134 carry_max = vpmax_f16(carry_max, carry_max);
135 carry_max = vpmax_f16(carry_max, carry_max);
136
137 *(reinterpret_cast<float16_t *>(output.ptr())) = vget_lane_f16(carry_max, 0);
138 }
139 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
140}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000141#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tellob49a7152017-07-11 16:31:35 +0100142
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100143void logits_1d_max_f32(const ITensor *in, ITensor *out, const Window &window)
144{
145 Window in_slice = window.first_slice_window_1D();
146
147 Window window_max(window);
148 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
149 Window max_slice = window_max.first_slice_window_1D();
150
151 do
152 {
153 Iterator input(in, in_slice);
154 Iterator output(out, max_slice);
155
156 float32x4_t vec_max = vdupq_n_f32(-FLT_MAX);
157
158 execute_window_loop(in_slice, [&](const Coordinates & id)
159 {
160 const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
161 const float32x4_t current_value = vld1q_f32(in_ptr);
162 vec_max = vmaxq_f32(vec_max, current_value);
163 },
164 input);
165
166 float32x2_t carry_max = vpmax_f32(vget_high_f32(vec_max), vget_low_f32(vec_max));
167 carry_max = vpmax_f32(carry_max, carry_max);
168
169 *(reinterpret_cast<float *>(output.ptr())) = vget_lane_f32(carry_max, 0);
170 }
171 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
172}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100173} // namespace
174
175NELogits1DMaxKernel::NELogits1DMaxKernel()
176 : _func(nullptr), _border_size()
177{
178}
179
180BorderSize NELogits1DMaxKernel::border_size() const
181{
182 return _border_size;
183}
184
185void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
186{
Pablo Tellob49a7152017-07-11 16:31:35 +0100187 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100188 ARM_COMPUTE_ERROR_ON_NULLPTR(output);
189
190 // Softmax across the x dimension
191 TensorShape output_shape{ input->info()->tensor_shape() };
192 output_shape.set(0, 1);
193
194 // Output auto initialization if not yet initialized
195 auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
196
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100197 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100198 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
199 ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100200
201 const int input_width = input->info()->valid_region().shape.x();
Georgios Pinitas9247c922017-06-28 18:29:47 +0100202 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100203
204 switch(input->info()->data_type())
205 {
206 case DataType::QS8:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100207 _func = &logits_1d_max_qs8;
208 break;
209 case DataType::QS16:
210 _func = &logits_1d_max_qs16;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100211 break;
212 case DataType::F32:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100213 _func = &logits_1d_max_f32;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100214 break;
Pablo Tellob49a7152017-07-11 16:31:35 +0100215 case DataType::F16:
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000216#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tellob49a7152017-07-11 16:31:35 +0100217 _func = &logits_1d_max_f16;
218 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000219#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100220 default:
221 ARM_COMPUTE_ERROR("Unsupported data type.");
222 }
223
224 _input = input;
225 _output = output;
Giorgio Arenaa2611812017-07-21 10:08:48 +0100226 _border_size = BorderSize(0, num_elems_processed_per_iteration - (input_width % num_elems_processed_per_iteration), 0, 0);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100227
228 // Configure kernel window
229 constexpr unsigned int num_elems_written_per_row = 1;
230
231 Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
232 AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
233 AccessWindowHorizontal output_access(output->info(), 0, num_elems_written_per_row, 1.f / input_width);
234
235 update_window_and_padding(win, input_access, output_access);
236
237 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
238
239 INEKernel::configure(win);
240}
241
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100242void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100243{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100244 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100245 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
246 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
247 ARM_COMPUTE_ERROR_ON(_func == nullptr);
248
249 (*_func)(_input, _output, window);
250}
251
252namespace
253{
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100254void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
255{
256 Window window_max(window);
257 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
258
259 Window max_slice = window_max.first_slice_window_1D();
260 Window in_slice = window.first_slice_window_1D();
261
262 constexpr int step = 8;
263 const int long_steps = in->info()->valid_region().shape.x() / step;
264 const int small_steps = in->info()->valid_region().shape.x() % step;
265 const int fixed_point_position = in->info()->fixed_point_position();
266
267 do
268 {
269 Iterator input(in, in_slice);
270 Iterator exp(out, in_slice);
271 Iterator _max(max, max_slice);
272 Iterator _sum(sum, max_slice);
273
274 // Get pointers
275 auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
276 auto exp_ptr = reinterpret_cast<qint8_t *>(exp.ptr());
277
278 // Init sum to zero
279 qint16x8_t vec_sum_value = vdupq_n_qs16(0);
280
281 // Get max value
282 const auto max_ptr = reinterpret_cast<const qint8_t *>(_max.ptr());
283 const qint8x8_t vec_max = vdup_n_qs8(*max_ptr);
284
285 // Run neon loop
286 for(int i = 0; i < long_steps; ++i)
287 {
288 qint8x8_t vec_elements = vld1_qs8(in_ptr);
289 vec_elements = vqsub_qs8(vec_elements, vec_max);
290 vec_elements = vqexp_qs8(vec_elements, fixed_point_position);
291
292 vst1_qs8(exp_ptr, vec_elements);
293 vec_sum_value = vqaddq_qs16(vec_sum_value, vmovl_s8(vec_elements));
294
295 in_ptr += step;
296 exp_ptr += step;
297 }
298 // Reduce sum
299 const qint16x4_t sum_red = vqadd_qs16(vget_low_s16(vec_sum_value), vget_high_s16(vec_sum_value));
300 const qint16_t sum0 = sqadd_qs16(vget_lane_s16(sum_red, 0), vget_lane_s16(sum_red, 1));
301 const qint16_t sum1 = sqadd_qs16(vget_lane_s16(sum_red, 2), vget_lane_s16(sum_red, 3));
302 qint16_t sum = sqadd_qs16(sum0, sum1);
303
304 // Run remaining elements
305 for(int i = 0; i < small_steps; ++i)
306 {
307 qint8_t element = sqexp_qs8(sqsub_qs8(in_ptr[i], *max_ptr), fixed_point_position);
308 exp_ptr[i] = element;
309 sum = sqadd_qs16(sum, element);
310 }
311
312 *(reinterpret_cast<qint8_t *>(_sum.ptr())) = sqmovn_qs16(sum);
313 }
314 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
315}
Georgios Pinitas9247c922017-06-28 18:29:47 +0100316void logits_1d_shift_exp_sum_qs16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
317{
318 Window window_max(window);
319 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
320
321 Window max_slice = window_max.first_slice_window_1D();
322 Window in_slice = window.first_slice_window_1D();
323
324 constexpr int step = 4;
325 const int long_steps = in->info()->valid_region().shape.x() / step;
326 const int small_steps = in->info()->valid_region().shape.x() % step;
327 const int fixed_point_position = in->info()->fixed_point_position();
328
329 do
330 {
331 Iterator input(in, in_slice);
332 Iterator exp(out, in_slice);
333 Iterator _max(max, max_slice);
334 Iterator _sum(sum, max_slice);
335
336 // Get pointers
337 auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
338 auto exp_ptr = reinterpret_cast<qint16_t *>(exp.ptr());
339
340 // Init sum to zero
341 qint32x4_t vec_sum_value = vdupq_n_qs32(0);
342
343 // Get max value
344 const auto max_ptr = reinterpret_cast<const qint16_t *>(_max.ptr());
345 const qint16x4_t vec_max = vdup_n_qs16(*max_ptr);
346
347 // Run neon loop
348 for(int i = 0; i < long_steps; ++i)
349 {
350 qint16x4_t vec_elements = vld1_qs16(in_ptr);
351 vec_elements = vqsub_qs16(vec_elements, vec_max);
352 vec_elements = vqexp_qs16(vec_elements, fixed_point_position);
353
354 vst1_qs16(exp_ptr, vec_elements);
355 vec_sum_value = vqaddq_qs32(vec_sum_value, vmovl_s16(vec_elements));
356
357 in_ptr += step;
358 exp_ptr += step;
359 }
360 // Reduce sum
361 qint32x2_t carry_addition = vqadd_qs32(vget_high_s32(vec_sum_value), vget_low_s32(vec_sum_value));
362 qint32_t sum = vget_lane_s32(carry_addition, 0) + vget_lane_s32(carry_addition, 1);
363
364 // Run remaining elements
365 for(int i = 0; i < small_steps; ++i)
366 {
367 qint16_t element = sqexp_qs16(sqsub_qs16(in_ptr[i], *max_ptr), fixed_point_position);
368 exp_ptr[i] = element;
369 sum = sqadd_qs32(sum, element);
370 }
371
372 *(reinterpret_cast<qint16_t *>(_sum.ptr())) = sqmovn_qs32(sum);
373 }
374 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
375}
Pablo Tellob49a7152017-07-11 16:31:35 +0100376
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000377#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tellob49a7152017-07-11 16:31:35 +0100378void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
379{
380 Window window_max(window);
381 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
382
383 Window max_slice = window_max.first_slice_window_1D();
384 Window in_slice = window.first_slice_window_1D();
385
386 constexpr int step = 8;
387 const int long_steps = in->info()->valid_region().shape.x() / step;
388 const int small_steps = in->info()->valid_region().shape.x() % step;
389
390 do
391 {
392 Iterator input(in, in_slice);
393 Iterator exp(out, in_slice);
394 Iterator _max(max, max_slice);
395 Iterator _sum(sum, max_slice);
396
397 // Get pointers
398 auto in_ptr = reinterpret_cast<const float16_t *>(input.ptr());
399 auto exp_ptr = reinterpret_cast<float16_t *>(exp.ptr());
400
401 // Init sum to zero
402 float16x8_t vec_sum_value = vdupq_n_f16(0);
403
404 // Get max value
405 const auto max_ptr = reinterpret_cast<const float16_t *>(_max.ptr());
406 const float16x8_t vec_max = vdupq_n_f16(*max_ptr);
407
408 // Run neon loop
409 for(int i = 0; i < long_steps; ++i)
410 {
411 float16x8_t vec_elements = vld1q_f16(in_ptr);
412 vec_elements = vsubq_f16(vec_elements, vec_max);
413 vec_elements = vexpq_f16(vec_elements);
414
415 vst1q_f16(exp_ptr, vec_elements);
416 vec_sum_value = vaddq_f16(vec_sum_value, vec_elements);
417
418 in_ptr += step;
419 exp_ptr += step;
420 }
421 // Reduce sum
422 const float16x4_t sum_red = vadd_f16(vget_low_f16(vec_sum_value), vget_high_f16(vec_sum_value));
423 const float16x4_t carry_addition = vpadd_f16(sum_red, sum_red);
424 float16_t sum = vget_lane_f16(carry_addition, 0) + vget_lane_f16(carry_addition, 1);
425
426 // Run remaining elements
427 for(int i = 0; i < small_steps; ++i)
428 {
429 const float16_t element = std::exp(static_cast<float>(in_ptr[i] - *max_ptr));
430 exp_ptr[i] = element;
431 sum += element;
432 }
433 *(reinterpret_cast<float16_t *>(_sum.ptr())) = sum;
434 }
435 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
436}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000437#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tellob49a7152017-07-11 16:31:35 +0100438
Georgios Pinitas9247c922017-06-28 18:29:47 +0100439void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
440{
441 Window window_max(window);
442 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
443
444 Window max_slice = window_max.first_slice_window_1D();
445 Window in_slice = window.first_slice_window_1D();
446
447 constexpr int step = 4;
448 const int long_steps = in->info()->valid_region().shape.x() / step;
449 const int small_steps = in->info()->valid_region().shape.x() % step;
450
451 do
452 {
453 Iterator input(in, in_slice);
454 Iterator exp(out, in_slice);
455 Iterator _max(max, max_slice);
456 Iterator _sum(sum, max_slice);
457
458 // Get pointers
459 auto in_ptr = reinterpret_cast<const float *>(input.ptr());
460 auto exp_ptr = reinterpret_cast<float *>(exp.ptr());
461
462 // Init sum to zero
463 float32x4_t vec_sum_value = vdupq_n_f32(0.0f);
464
465 // Get max value
466 const auto max_ptr = reinterpret_cast<const float *>(_max.ptr());
467 const float32x4_t vec_max = vdupq_n_f32(*max_ptr);
468
469 // Run neon loop
470 for(int i = 0; i < long_steps; ++i)
471 {
472 float32x4_t vec_elements = vld1q_f32(in_ptr);
473 vec_elements = vsubq_f32(vec_elements, vec_max);
474 vec_elements = vexpq_f32(vec_elements);
475
476 vst1q_f32(exp_ptr, vec_elements);
477 vec_sum_value = vaddq_f32(vec_elements, vec_sum_value);
478
479 in_ptr += step;
480 exp_ptr += step;
481 }
482
483 // Reduce sum
484 float32x2_t carry_addition = vpadd_f32(vget_high_f32(vec_sum_value), vget_low_f32(vec_sum_value));
485 carry_addition = vpadd_f32(carry_addition, carry_addition);
486 float sum = vget_lane_f32(carry_addition, 0);
487
488 // Run remaining elements
489 for(int i = 0; i < small_steps; ++i)
490 {
491 float element = std::exp(in_ptr[i] - *max_ptr);
492 exp_ptr[i] = element;
493 sum += element;
494 }
495
496 *(reinterpret_cast<float *>(_sum.ptr())) = sum;
497 }
498 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
499}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100500} //namespace
501
502NELogits1DShiftExpSumKernel::NELogits1DShiftExpSumKernel()
503 : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr)
504{
505}
506
507void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum)
508{
Pablo Tellob49a7152017-07-11 16:31:35 +0100509 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100510 ARM_COMPUTE_ERROR_ON_NULLPTR(max, sum, output);
511
512 // Output auto initialization if not yet initialized
513 auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
514 auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
515
516 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, max, sum);
517 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output, max, sum);
518 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100519 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(max, sum);
520
521 unsigned int num_elems_processed_per_iteration = input->info()->valid_region().shape.x();
522
523 switch(input->info()->data_type())
524 {
525 case DataType::QS8:
526 _func = &logits_1d_shift_exp_sum_qs8;
527 break;
Georgios Pinitas9247c922017-06-28 18:29:47 +0100528 case DataType::QS16:
529 _func = &logits_1d_shift_exp_sum_qs16;
530 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100531 case DataType::F32:
532 _func = &logits_1d_shift_exp_sum_f32;
533 break;
Pablo Tellob49a7152017-07-11 16:31:35 +0100534 case DataType::F16:
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000535#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tellob49a7152017-07-11 16:31:35 +0100536 _func = &logits_1d_shift_exp_sum_f16;
537 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000538#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100539 default:
540 ARM_COMPUTE_ERROR("Unsupported data type.");
Pablo Tellob49a7152017-07-11 16:31:35 +0100541 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100542 }
543
544 _input = input;
545 _max = max;
546 _output = output;
547 _sum = sum;
548
549 // Configure kernel window
550 Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
551 AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
552 AccessWindowHorizontal max_access(max->info(), 0, 1);
553 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
554 AccessWindowHorizontal sum_access(sum->info(), 0, 1);
555
556 update_window_and_padding(win, input_access, max_access, output_access, sum_access);
557
558 output_access.set_valid_region(win, input->info()->valid_region());
559 sum_access.set_valid_region(win, ValidRegion(Coordinates(), sum->info()->tensor_shape()));
560
561 INEKernel::configure(win);
562}
563
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100564void NELogits1DShiftExpSumKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100565{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100566 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100567 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
568 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
569 ARM_COMPUTE_ERROR_ON(_func == nullptr);
570
571 (*_func)(_input, _max, _output, _sum, window);
572}
573
574namespace
575{
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100576void logits_1d_norm_qs8(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
577{
578 Window window_sum(window);
579 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
580 Window sum_slice = window_sum.first_slice_window_1D();
581 Window in_slice = window.first_slice_window_1D();
582
583 const int fixed_point_position = in->info()->fixed_point_position();
584
585 do
586 {
587 Iterator input(in, in_slice);
588 Iterator _sum(sum, sum_slice);
589 Iterator output(out, in_slice);
590
591 const int8_t sum_value = *reinterpret_cast<const qint8_t *>(_sum.ptr());
592 const qint8x16_t vec_sum_inversed = vqrecipq_qs8(vdupq_n_qs8(sum_value), fixed_point_position);
593
594 execute_window_loop(in_slice, [&](const Coordinates & id)
595 {
596 const auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
597 const auto out_ptr = reinterpret_cast<qint8_t *>(output.ptr());
598
599 const qint8x16_t vec_in = vld1q_qs8(in_ptr);
600 const qint8x16_t normalized_value = vqmulq_qs8(vec_in, vec_sum_inversed, fixed_point_position);
601
602 vst1q_qs8(out_ptr, normalized_value);
603 },
604 input, output);
605 }
606 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
607}
Georgios Pinitas9247c922017-06-28 18:29:47 +0100608void logits_1d_norm_qs16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
609{
610 Window window_sum(window);
611 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
612 Window sum_slice = window_sum.first_slice_window_1D();
613 Window in_slice = window.first_slice_window_1D();
614
615 const int fixed_point_position = in->info()->fixed_point_position();
616
617 do
618 {
619 Iterator input(in, in_slice);
620 Iterator _sum(sum, sum_slice);
621 Iterator output(out, in_slice);
622
623 const int16_t sum_value = *reinterpret_cast<const qint16_t *>(_sum.ptr());
624 const qint16x8_t vec_sum_inversed = vqrecipq_qs16(vdupq_n_qs16(sum_value), fixed_point_position);
625
626 execute_window_loop(in_slice, [&](const Coordinates & id)
627 {
628 const auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
629 const auto out_ptr = reinterpret_cast<qint16_t *>(output.ptr());
630
631 const qint16x8_t vec_in = vld1q_qs16(in_ptr);
632 const qint16x8_t normalized_value = vqmulq_qs16(vec_in, vec_sum_inversed, fixed_point_position);
633
634 vst1q_qs16(out_ptr, normalized_value);
635 },
636 input, output);
637 }
638 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
639}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000640#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tellob49a7152017-07-11 16:31:35 +0100641void logits_1d_norm_f16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
642{
643 Window window_sum(window);
644 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
645 Window sum_slice = window_sum.first_slice_window_1D();
646 Window in_slice = window.first_slice_window_1D();
647
648 do
649 {
650 Iterator input(in, in_slice);
651 Iterator _sum(sum, sum_slice);
652 Iterator output(out, in_slice);
653
654 const float16_t sum_value = *reinterpret_cast<const qint16_t *>(_sum.ptr());
655 const float16x8_t vec_sum_inversed = vdupq_n_f16(1.0f / sum_value);
656
657 execute_window_loop(in_slice, [&](const Coordinates & id)
658 {
659 const auto in_ptr = reinterpret_cast<const float16_t *>(input.ptr());
660 const auto out_ptr = reinterpret_cast<float16_t *>(output.ptr());
661
662 const float16x8_t vec_in = vld1q_f16(in_ptr);
663 const float16x8_t normalized_value = vmulq_f16(vec_in, vec_sum_inversed);
664
665 vst1q_f16(out_ptr, normalized_value);
666 },
667 input, output);
668 }
669 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
670}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000671#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tellob49a7152017-07-11 16:31:35 +0100672
Georgios Pinitas9247c922017-06-28 18:29:47 +0100673void logits_1d_norm_f32(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
674{
675 Window window_sum(window);
676 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
677 Window sum_slice = window_sum.first_slice_window_1D();
678 Window in_slice = window.first_slice_window_1D();
679
680 do
681 {
682 Iterator input(in, in_slice);
683 Iterator _sum(sum, sum_slice);
684 Iterator output(out, in_slice);
685
686 const float sum_value = *reinterpret_cast<const float *>(_sum.ptr());
687 const float32x4_t vec_sum_inversed = vdupq_n_f32(1.0f / sum_value);
688
689 execute_window_loop(in_slice, [&](const Coordinates & id)
690 {
691 const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
692 const auto out_ptr = reinterpret_cast<float *>(output.ptr());
693
694 const float32x4_t vec_in = vld1q_f32(in_ptr);
695 const float32x4_t normalized_value = vmulq_f32(vec_in, vec_sum_inversed);
696
697 vst1q_f32(out_ptr, normalized_value);
698 },
699 input, output);
700 }
701 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
702}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100703} // namespace
704
705NELogits1DNormKernel::NELogits1DNormKernel()
706 : _func(nullptr), _input(nullptr), _sum(nullptr), _output(nullptr)
707{
708}
709
710void NELogits1DNormKernel::configure(const ITensor *input, const ITensor *sum, ITensor *output)
711{
Pablo Tellob49a7152017-07-11 16:31:35 +0100712 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100713 ARM_COMPUTE_ERROR_ON_NULLPTR(sum, output);
714
715 // Output auto initialization if not yet initialized
716 auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
717
718 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum, output);
719 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100720 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
721
722 _input = input;
723 _sum = sum;
724 _output = output;
725
726 // Configure kernel window
Georgios Pinitas9247c922017-06-28 18:29:47 +0100727 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100728
729 switch(input->info()->data_type())
730 {
731 case DataType::QS8:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100732 _func = &logits_1d_norm_qs8;
733 break;
734 case DataType::QS16:
735 _func = &logits_1d_norm_qs16;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100736 break;
737 case DataType::F32:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100738 _func = &logits_1d_norm_f32;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100739 break;
Pablo Tellob49a7152017-07-11 16:31:35 +0100740 case DataType::F16:
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000741#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tellob49a7152017-07-11 16:31:35 +0100742 _func = &logits_1d_norm_f16;
743 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000744#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100745 default:
746 ARM_COMPUTE_ERROR("Unsupported data type.");
Pablo Tellob49a7152017-07-11 16:31:35 +0100747 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100748 }
749
750 Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
751
752 AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
753 AccessWindowStatic sum_access(sum->info(), 0, 0, 1, sum->info()->dimension(1));
754 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
755
756 update_window_and_padding(win, input_access, sum_access, output_access);
757
758 output_access.set_valid_region(win, input->info()->valid_region());
759
760 INEKernel::configure(win);
761}
762
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100763void NELogits1DNormKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100764{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100765 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100766 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
767 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
768 ARM_COMPUTE_ERROR_ON(_func == nullptr);
769
770 (*_func)(_input, _sum, _output, window);
771}