blob: 4fed16b5fa01ed404c513b2d9ca36b4bafc2613f [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"
25
26#include "arm_compute/core/AccessWindowStatic.h"
27#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/ITensor.h"
30#include "arm_compute/core/NEON/NEFixedPoint.h"
31#include "arm_compute/core/NEON/NEMath.h"
32#include "arm_compute/core/TensorInfo.h"
33#include "arm_compute/core/Utils.h"
34#include "arm_compute/core/Validate.h"
35#include "arm_compute/core/Window.h"
36
37#include <algorithm>
38#include <arm_neon.h>
39#include <cfloat>
40
41using namespace arm_compute;
42
43namespace
44{
Georgios Pinitas9247c922017-06-28 18:29:47 +010045void logits_1d_max_qs8(const ITensor *in, ITensor *out, const Window &window)
46{
47 Window in_slice = window.first_slice_window_1D();
48
49 Window window_max(window);
50 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
51 Window max_slice = window_max.first_slice_window_1D();
52
53 do
54 {
55 Iterator input(in, in_slice);
56 Iterator output(out, max_slice);
57
58 qint8x16_t vec_max = vdupq_n_s8(std::numeric_limits<qint8_t>::lowest());
59
60 execute_window_loop(in_slice, [&](const Coordinates & id)
61 {
62 const auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
63 const qint8x16_t current_value = vld1q_qs8(in_ptr);
64 vec_max = vmaxq_qs8(vec_max, current_value);
65 },
66 input);
67
68 qint8x8_t carry_max = vpmax_qs8(vget_high_s8(vec_max), vget_low_s8(vec_max));
69 carry_max = vpmax_qs8(carry_max, carry_max);
70 carry_max = vpmax_qs8(carry_max, carry_max);
71 carry_max = vpmax_qs8(carry_max, carry_max);
72
73 *(reinterpret_cast<qint8_t *>(output.ptr())) = vget_lane_s8(carry_max, 0);
74 }
75 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
76}
77void logits_1d_max_qs16(const ITensor *in, ITensor *out, const Window &window)
78{
79 Window in_slice = window.first_slice_window_1D();
80
81 Window window_max(window);
82 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
83 Window max_slice = window_max.first_slice_window_1D();
84
85 do
86 {
87 Iterator input(in, in_slice);
88 Iterator output(out, max_slice);
89
90 qint16x8_t vec_max = vdupq_n_qs16(std::numeric_limits<qint16_t>::lowest());
91
92 execute_window_loop(in_slice, [&](const Coordinates & id)
93 {
94 const auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
95 const qint16x8_t current_value = vld1q_qs16(in_ptr);
96 vec_max = vmaxq_qs16(vec_max, current_value);
97 },
98 input);
99
100 qint16x4_t carry_max = vpmax_qs16(vget_high_qs16(vec_max), vget_low_qs16(vec_max));
101 carry_max = vpmax_qs16(carry_max, carry_max);
102 carry_max = vpmax_qs16(carry_max, carry_max);
103
104 *(reinterpret_cast<qint16_t *>(output.ptr())) = vget_lane_s16(carry_max, 0);
105 }
106 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
107}
Pablo Tellob49a7152017-07-11 16:31:35 +0100108
#ifdef ARM_COMPUTE_ENABLE_FP16
/** Computes the maximum F16 value of every 1D row of @p in and stores it in the matching element of @p out.
 *
 * @param[in]  in     Input tensor (F16). Each window step loads one 128-bit vector (8 elements).
 * @param[out] out    Output tensor receiving one F16 value per row.
 * @param[in]  window Execution window over @p in.
 */
void logits_1d_max_f16(const ITensor *in, ITensor *out, const Window &window)
{
    // Most negative finite IEEE half-precision value, used as the identity of the max reduction.
    // std::numeric_limits is not guaranteed to be specialised for float16_t (__fp16); the primary
    // template's lowest() returns float16_t() == 0, which would report 0 as the maximum of a row
    // whose values are all negative.
    constexpr float fp16_lowest = -65504.0f;

    Window in_slice = window.first_slice_window_1D();

    Window window_max(window);
    window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
    Window max_slice = window_max.first_slice_window_1D();

    do
    {
        Iterator input(in, in_slice);
        Iterator output(out, max_slice);

        float16x8_t vec_max = vdupq_n_f16(static_cast<float16_t>(fp16_lowest));

        execute_window_loop(in_slice, [&](const Coordinates &)
        {
            const auto        in_ptr        = reinterpret_cast<const float16_t *>(input.ptr());
            const float16x8_t current_value = vld1q_f16(in_ptr);
            vec_max                         = vmaxq_f16(vec_max, current_value);
        },
        input);

        // Pairwise-max reduction: 8 candidates -> 4 -> 2 -> 1 (lane 0).
        float16x4_t carry_max = vpmax_f16(vget_high_f16(vec_max), vget_low_f16(vec_max));
        carry_max             = vpmax_f16(carry_max, carry_max);
        carry_max             = vpmax_f16(carry_max, carry_max);

        *(reinterpret_cast<float16_t *>(output.ptr())) = vget_lane_f16(carry_max, 0);
    }
    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
}
#endif /* ARM_COMPUTE_ENABLE_FP16 */
142
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100143void logits_1d_max_f32(const ITensor *in, ITensor *out, const Window &window)
144{
145 Window in_slice = window.first_slice_window_1D();
146
147 Window window_max(window);
148 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
149 Window max_slice = window_max.first_slice_window_1D();
150
151 do
152 {
153 Iterator input(in, in_slice);
154 Iterator output(out, max_slice);
155
156 float32x4_t vec_max = vdupq_n_f32(-FLT_MAX);
157
158 execute_window_loop(in_slice, [&](const Coordinates & id)
159 {
160 const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
161 const float32x4_t current_value = vld1q_f32(in_ptr);
162 vec_max = vmaxq_f32(vec_max, current_value);
163 },
164 input);
165
166 float32x2_t carry_max = vpmax_f32(vget_high_f32(vec_max), vget_low_f32(vec_max));
167 carry_max = vpmax_f32(carry_max, carry_max);
168
169 *(reinterpret_cast<float *>(output.ptr())) = vget_lane_f32(carry_max, 0);
170 }
171 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
172}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100173} // namespace
174
175NELogits1DMaxKernel::NELogits1DMaxKernel()
176 : _func(nullptr), _border_size()
177{
178}
179
180BorderSize NELogits1DMaxKernel::border_size() const
181{
182 return _border_size;
183}
184
185void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
186{
Pablo Tellob49a7152017-07-11 16:31:35 +0100187 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100188 ARM_COMPUTE_ERROR_ON_NULLPTR(output);
189
190 // Softmax across the x dimension
191 TensorShape output_shape{ input->info()->tensor_shape() };
192 output_shape.set(0, 1);
193
194 // Output auto initialization if not yet initialized
195 auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
196
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100197 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100198 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
199 ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100200
201 const int input_width = input->info()->valid_region().shape.x();
Georgios Pinitas9247c922017-06-28 18:29:47 +0100202 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100203
204 switch(input->info()->data_type())
205 {
206 case DataType::QS8:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100207 _func = &logits_1d_max_qs8;
208 break;
209 case DataType::QS16:
210 _func = &logits_1d_max_qs16;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100211 break;
212 case DataType::F32:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100213 _func = &logits_1d_max_f32;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100214 break;
Pablo Tellob49a7152017-07-11 16:31:35 +0100215 case DataType::F16:
216#ifdef ARM_COMPUTE_ENABLE_FP16
217 _func = &logits_1d_max_f16;
218 break;
219#endif /* ARM_COMPUTE_ENABLE_FP16 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100220 default:
221 ARM_COMPUTE_ERROR("Unsupported data type.");
222 }
223
224 _input = input;
225 _output = output;
Giorgio Arenaa2611812017-07-21 10:08:48 +0100226 _border_size = BorderSize(0, num_elems_processed_per_iteration - (input_width % num_elems_processed_per_iteration), 0, 0);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100227
228 // Configure kernel window
229 constexpr unsigned int num_elems_written_per_row = 1;
230
231 Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
232 AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
233 AccessWindowHorizontal output_access(output->info(), 0, num_elems_written_per_row, 1.f / input_width);
234
235 update_window_and_padding(win, input_access, output_access);
236
237 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
238
239 INEKernel::configure(win);
240}
241
242void NELogits1DMaxKernel::run(const Window &window)
243{
244 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
245 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
246 ARM_COMPUTE_ERROR_ON(_func == nullptr);
247
248 (*_func)(_input, _output, window);
249}
250
251namespace
252{
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100253void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
254{
255 Window window_max(window);
256 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
257
258 Window max_slice = window_max.first_slice_window_1D();
259 Window in_slice = window.first_slice_window_1D();
260
261 constexpr int step = 8;
262 const int long_steps = in->info()->valid_region().shape.x() / step;
263 const int small_steps = in->info()->valid_region().shape.x() % step;
264 const int fixed_point_position = in->info()->fixed_point_position();
265
266 do
267 {
268 Iterator input(in, in_slice);
269 Iterator exp(out, in_slice);
270 Iterator _max(max, max_slice);
271 Iterator _sum(sum, max_slice);
272
273 // Get pointers
274 auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
275 auto exp_ptr = reinterpret_cast<qint8_t *>(exp.ptr());
276
277 // Init sum to zero
278 qint16x8_t vec_sum_value = vdupq_n_qs16(0);
279
280 // Get max value
281 const auto max_ptr = reinterpret_cast<const qint8_t *>(_max.ptr());
282 const qint8x8_t vec_max = vdup_n_qs8(*max_ptr);
283
284 // Run neon loop
285 for(int i = 0; i < long_steps; ++i)
286 {
287 qint8x8_t vec_elements = vld1_qs8(in_ptr);
288 vec_elements = vqsub_qs8(vec_elements, vec_max);
289 vec_elements = vqexp_qs8(vec_elements, fixed_point_position);
290
291 vst1_qs8(exp_ptr, vec_elements);
292 vec_sum_value = vqaddq_qs16(vec_sum_value, vmovl_s8(vec_elements));
293
294 in_ptr += step;
295 exp_ptr += step;
296 }
297 // Reduce sum
298 const qint16x4_t sum_red = vqadd_qs16(vget_low_s16(vec_sum_value), vget_high_s16(vec_sum_value));
299 const qint16_t sum0 = sqadd_qs16(vget_lane_s16(sum_red, 0), vget_lane_s16(sum_red, 1));
300 const qint16_t sum1 = sqadd_qs16(vget_lane_s16(sum_red, 2), vget_lane_s16(sum_red, 3));
301 qint16_t sum = sqadd_qs16(sum0, sum1);
302
303 // Run remaining elements
304 for(int i = 0; i < small_steps; ++i)
305 {
306 qint8_t element = sqexp_qs8(sqsub_qs8(in_ptr[i], *max_ptr), fixed_point_position);
307 exp_ptr[i] = element;
308 sum = sqadd_qs16(sum, element);
309 }
310
311 *(reinterpret_cast<qint8_t *>(_sum.ptr())) = sqmovn_qs16(sum);
312 }
313 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
314}
Georgios Pinitas9247c922017-06-28 18:29:47 +0100315void logits_1d_shift_exp_sum_qs16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
316{
317 Window window_max(window);
318 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
319
320 Window max_slice = window_max.first_slice_window_1D();
321 Window in_slice = window.first_slice_window_1D();
322
323 constexpr int step = 4;
324 const int long_steps = in->info()->valid_region().shape.x() / step;
325 const int small_steps = in->info()->valid_region().shape.x() % step;
326 const int fixed_point_position = in->info()->fixed_point_position();
327
328 do
329 {
330 Iterator input(in, in_slice);
331 Iterator exp(out, in_slice);
332 Iterator _max(max, max_slice);
333 Iterator _sum(sum, max_slice);
334
335 // Get pointers
336 auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
337 auto exp_ptr = reinterpret_cast<qint16_t *>(exp.ptr());
338
339 // Init sum to zero
340 qint32x4_t vec_sum_value = vdupq_n_qs32(0);
341
342 // Get max value
343 const auto max_ptr = reinterpret_cast<const qint16_t *>(_max.ptr());
344 const qint16x4_t vec_max = vdup_n_qs16(*max_ptr);
345
346 // Run neon loop
347 for(int i = 0; i < long_steps; ++i)
348 {
349 qint16x4_t vec_elements = vld1_qs16(in_ptr);
350 vec_elements = vqsub_qs16(vec_elements, vec_max);
351 vec_elements = vqexp_qs16(vec_elements, fixed_point_position);
352
353 vst1_qs16(exp_ptr, vec_elements);
354 vec_sum_value = vqaddq_qs32(vec_sum_value, vmovl_s16(vec_elements));
355
356 in_ptr += step;
357 exp_ptr += step;
358 }
359 // Reduce sum
360 qint32x2_t carry_addition = vqadd_qs32(vget_high_s32(vec_sum_value), vget_low_s32(vec_sum_value));
361 qint32_t sum = vget_lane_s32(carry_addition, 0) + vget_lane_s32(carry_addition, 1);
362
363 // Run remaining elements
364 for(int i = 0; i < small_steps; ++i)
365 {
366 qint16_t element = sqexp_qs16(sqsub_qs16(in_ptr[i], *max_ptr), fixed_point_position);
367 exp_ptr[i] = element;
368 sum = sqadd_qs32(sum, element);
369 }
370
371 *(reinterpret_cast<qint16_t *>(_sum.ptr())) = sqmovn_qs32(sum);
372 }
373 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
374}
Pablo Tellob49a7152017-07-11 16:31:35 +0100375
#ifdef ARM_COMPUTE_ENABLE_FP16
/** F16 shift/exp/sum: for each row, writes exp(in[i] - max) to @p out and the sum of those values to @p sum.
 *
 * @param[in]  in     Input tensor (F16).
 * @param[in]  max    Per-row maximum (one element per row).
 * @param[out] out    Receives the exponentials, same layout as @p in.
 * @param[out] sum    Receives one sum per row.
 * @param[in]  window Execution window; each row is processed entirely within one slice.
 */
void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
{
    constexpr int lanes        = 8;
    const int     width        = in->info()->valid_region().shape.x();
    const int     vector_iters = width / lanes;
    const int     scalar_tail  = width % lanes;

    Window in_slice = window.first_slice_window_1D();

    Window row_window(window);
    row_window.set(Window::DimX, Window::Dimension(0, 0, 0));
    Window row_slice = row_window.first_slice_window_1D();

    do
    {
        Iterator src(in, in_slice);
        Iterator dst(out, in_slice);
        Iterator max_it(max, row_slice);
        Iterator sum_it(sum, row_slice);

        auto       src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
        auto       dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
        const auto row_max = reinterpret_cast<const float16_t *>(max_it.ptr());

        const float16x8_t vmax = vdupq_n_f16(*row_max);
        float16x8_t       vacc = vdupq_n_f16(0);

        for(int v = 0; v < vector_iters; ++v)
        {
            const float16x8_t e = vexpq_f16(vsubq_f16(vld1q_f16(src_ptr), vmax));
            vst1q_f16(dst_ptr, e);
            vacc = vaddq_f16(vacc, e);

            src_ptr += lanes;
            dst_ptr += lanes;
        }

        // 8 lanes -> 4 by adding halves, then a pairwise add leaves the two partial sums in lanes 0 and 1.
        const float16x4_t halves = vadd_f16(vget_low_f16(vacc), vget_high_f16(vacc));
        const float16x4_t pairs  = vpadd_f16(halves, halves);
        float16_t         total  = vget_lane_f16(pairs, 0) + vget_lane_f16(pairs, 1);

        // Scalar tail for the elements that do not fill a vector.
        for(int k = 0; k < scalar_tail; ++k)
        {
            const float16_t e = std::exp(static_cast<float>(src_ptr[k] - *row_max));
            dst_ptr[k]        = e;
            total += e;
        }

        *reinterpret_cast<float16_t *>(sum_it.ptr()) = total;
    }
    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(row_slice));
}
#endif /* ARM_COMPUTE_ENABLE_FP16 */
437
Georgios Pinitas9247c922017-06-28 18:29:47 +0100438void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
439{
440 Window window_max(window);
441 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
442
443 Window max_slice = window_max.first_slice_window_1D();
444 Window in_slice = window.first_slice_window_1D();
445
446 constexpr int step = 4;
447 const int long_steps = in->info()->valid_region().shape.x() / step;
448 const int small_steps = in->info()->valid_region().shape.x() % step;
449
450 do
451 {
452 Iterator input(in, in_slice);
453 Iterator exp(out, in_slice);
454 Iterator _max(max, max_slice);
455 Iterator _sum(sum, max_slice);
456
457 // Get pointers
458 auto in_ptr = reinterpret_cast<const float *>(input.ptr());
459 auto exp_ptr = reinterpret_cast<float *>(exp.ptr());
460
461 // Init sum to zero
462 float32x4_t vec_sum_value = vdupq_n_f32(0.0f);
463
464 // Get max value
465 const auto max_ptr = reinterpret_cast<const float *>(_max.ptr());
466 const float32x4_t vec_max = vdupq_n_f32(*max_ptr);
467
468 // Run neon loop
469 for(int i = 0; i < long_steps; ++i)
470 {
471 float32x4_t vec_elements = vld1q_f32(in_ptr);
472 vec_elements = vsubq_f32(vec_elements, vec_max);
473 vec_elements = vexpq_f32(vec_elements);
474
475 vst1q_f32(exp_ptr, vec_elements);
476 vec_sum_value = vaddq_f32(vec_elements, vec_sum_value);
477
478 in_ptr += step;
479 exp_ptr += step;
480 }
481
482 // Reduce sum
483 float32x2_t carry_addition = vpadd_f32(vget_high_f32(vec_sum_value), vget_low_f32(vec_sum_value));
484 carry_addition = vpadd_f32(carry_addition, carry_addition);
485 float sum = vget_lane_f32(carry_addition, 0);
486
487 // Run remaining elements
488 for(int i = 0; i < small_steps; ++i)
489 {
490 float element = std::exp(in_ptr[i] - *max_ptr);
491 exp_ptr[i] = element;
492 sum += element;
493 }
494
495 *(reinterpret_cast<float *>(_sum.ptr())) = sum;
496 }
497 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
498}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100499} //namespace
500
501NELogits1DShiftExpSumKernel::NELogits1DShiftExpSumKernel()
502 : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr)
503{
504}
505
506void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum)
507{
Pablo Tellob49a7152017-07-11 16:31:35 +0100508 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100509 ARM_COMPUTE_ERROR_ON_NULLPTR(max, sum, output);
510
511 // Output auto initialization if not yet initialized
512 auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
513 auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
514
515 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, max, sum);
516 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output, max, sum);
517 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100518 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(max, sum);
519
520 unsigned int num_elems_processed_per_iteration = input->info()->valid_region().shape.x();
521
522 switch(input->info()->data_type())
523 {
524 case DataType::QS8:
525 _func = &logits_1d_shift_exp_sum_qs8;
526 break;
Georgios Pinitas9247c922017-06-28 18:29:47 +0100527 case DataType::QS16:
528 _func = &logits_1d_shift_exp_sum_qs16;
529 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100530 case DataType::F32:
531 _func = &logits_1d_shift_exp_sum_f32;
532 break;
Pablo Tellob49a7152017-07-11 16:31:35 +0100533 case DataType::F16:
534#ifdef ARM_COMPUTE_ENABLE_FP16
535 _func = &logits_1d_shift_exp_sum_f16;
536 break;
537#endif /* ARM_COMPUTE_ENABLE_FP16 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100538 default:
539 ARM_COMPUTE_ERROR("Unsupported data type.");
Pablo Tellob49a7152017-07-11 16:31:35 +0100540 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100541 }
542
543 _input = input;
544 _max = max;
545 _output = output;
546 _sum = sum;
547
548 // Configure kernel window
549 Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
550 AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
551 AccessWindowHorizontal max_access(max->info(), 0, 1);
552 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
553 AccessWindowHorizontal sum_access(sum->info(), 0, 1);
554
555 update_window_and_padding(win, input_access, max_access, output_access, sum_access);
556
557 output_access.set_valid_region(win, input->info()->valid_region());
558 sum_access.set_valid_region(win, ValidRegion(Coordinates(), sum->info()->tensor_shape()));
559
560 INEKernel::configure(win);
561}
562
563void NELogits1DShiftExpSumKernel::run(const Window &window)
564{
565 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
566 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
567 ARM_COMPUTE_ERROR_ON(_func == nullptr);
568
569 (*_func)(_input, _max, _output, _sum, window);
570}
571
572namespace
573{
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100574void logits_1d_norm_qs8(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
575{
576 Window window_sum(window);
577 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
578 Window sum_slice = window_sum.first_slice_window_1D();
579 Window in_slice = window.first_slice_window_1D();
580
581 const int fixed_point_position = in->info()->fixed_point_position();
582
583 do
584 {
585 Iterator input(in, in_slice);
586 Iterator _sum(sum, sum_slice);
587 Iterator output(out, in_slice);
588
589 const int8_t sum_value = *reinterpret_cast<const qint8_t *>(_sum.ptr());
590 const qint8x16_t vec_sum_inversed = vqrecipq_qs8(vdupq_n_qs8(sum_value), fixed_point_position);
591
592 execute_window_loop(in_slice, [&](const Coordinates & id)
593 {
594 const auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
595 const auto out_ptr = reinterpret_cast<qint8_t *>(output.ptr());
596
597 const qint8x16_t vec_in = vld1q_qs8(in_ptr);
598 const qint8x16_t normalized_value = vqmulq_qs8(vec_in, vec_sum_inversed, fixed_point_position);
599
600 vst1q_qs8(out_ptr, normalized_value);
601 },
602 input, output);
603 }
604 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
605}
Georgios Pinitas9247c922017-06-28 18:29:47 +0100606void logits_1d_norm_qs16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
607{
608 Window window_sum(window);
609 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
610 Window sum_slice = window_sum.first_slice_window_1D();
611 Window in_slice = window.first_slice_window_1D();
612
613 const int fixed_point_position = in->info()->fixed_point_position();
614
615 do
616 {
617 Iterator input(in, in_slice);
618 Iterator _sum(sum, sum_slice);
619 Iterator output(out, in_slice);
620
621 const int16_t sum_value = *reinterpret_cast<const qint16_t *>(_sum.ptr());
622 const qint16x8_t vec_sum_inversed = vqrecipq_qs16(vdupq_n_qs16(sum_value), fixed_point_position);
623
624 execute_window_loop(in_slice, [&](const Coordinates & id)
625 {
626 const auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
627 const auto out_ptr = reinterpret_cast<qint16_t *>(output.ptr());
628
629 const qint16x8_t vec_in = vld1q_qs16(in_ptr);
630 const qint16x8_t normalized_value = vqmulq_qs16(vec_in, vec_sum_inversed, fixed_point_position);
631
632 vst1q_qs16(out_ptr, normalized_value);
633 },
634 input, output);
635 }
636 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
637}
#ifdef ARM_COMPUTE_ENABLE_FP16
/** F16 normalisation: multiplies every element of a row by the reciprocal of that row's sum.
 *
 * @param[in]  in     Input tensor (F16), typically the exponentials.
 * @param[in]  sum    Per-row sums (F16, one element per row).
 * @param[out] out    Normalised output, same layout as @p in.
 * @param[in]  window Execution window; each step handles one 128-bit vector (8 elements).
 */
void logits_1d_norm_f16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
{
    Window window_sum(window);
    window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
    Window sum_slice = window_sum.first_slice_window_1D();
    Window in_slice  = window.first_slice_window_1D();

    do
    {
        Iterator input(in, in_slice);
        Iterator _sum(sum, sum_slice);
        Iterator output(out, in_slice);

        // The sum tensor stores half-precision floats, so it must be read as float16_t. Reading it
        // through a qint16_t pointer would reinterpret the F16 bit pattern as an integer value.
        const float16_t   sum_value        = *reinterpret_cast<const float16_t *>(_sum.ptr());
        const float16x8_t vec_sum_inversed = vdupq_n_f16(1.0f / sum_value);

        execute_window_loop(in_slice, [&](const Coordinates &)
        {
            const auto in_ptr  = reinterpret_cast<const float16_t *>(input.ptr());
            const auto out_ptr = reinterpret_cast<float16_t *>(output.ptr());

            const float16x8_t vec_in           = vld1q_f16(in_ptr);
            const float16x8_t normalized_value = vmulq_f16(vec_in, vec_sum_inversed);

            vst1q_f16(out_ptr, normalized_value);
        },
        input, output);
    }
    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
}
#endif /* ARM_COMPUTE_ENABLE_FP16 */
670
Georgios Pinitas9247c922017-06-28 18:29:47 +0100671void logits_1d_norm_f32(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
672{
673 Window window_sum(window);
674 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
675 Window sum_slice = window_sum.first_slice_window_1D();
676 Window in_slice = window.first_slice_window_1D();
677
678 do
679 {
680 Iterator input(in, in_slice);
681 Iterator _sum(sum, sum_slice);
682 Iterator output(out, in_slice);
683
684 const float sum_value = *reinterpret_cast<const float *>(_sum.ptr());
685 const float32x4_t vec_sum_inversed = vdupq_n_f32(1.0f / sum_value);
686
687 execute_window_loop(in_slice, [&](const Coordinates & id)
688 {
689 const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
690 const auto out_ptr = reinterpret_cast<float *>(output.ptr());
691
692 const float32x4_t vec_in = vld1q_f32(in_ptr);
693 const float32x4_t normalized_value = vmulq_f32(vec_in, vec_sum_inversed);
694
695 vst1q_f32(out_ptr, normalized_value);
696 },
697 input, output);
698 }
699 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
700}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100701} // namespace
702
703NELogits1DNormKernel::NELogits1DNormKernel()
704 : _func(nullptr), _input(nullptr), _sum(nullptr), _output(nullptr)
705{
706}
707
708void NELogits1DNormKernel::configure(const ITensor *input, const ITensor *sum, ITensor *output)
709{
Pablo Tellob49a7152017-07-11 16:31:35 +0100710 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100711 ARM_COMPUTE_ERROR_ON_NULLPTR(sum, output);
712
713 // Output auto initialization if not yet initialized
714 auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
715
716 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum, output);
717 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100718 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
719
720 _input = input;
721 _sum = sum;
722 _output = output;
723
724 // Configure kernel window
Georgios Pinitas9247c922017-06-28 18:29:47 +0100725 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100726
727 switch(input->info()->data_type())
728 {
729 case DataType::QS8:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100730 _func = &logits_1d_norm_qs8;
731 break;
732 case DataType::QS16:
733 _func = &logits_1d_norm_qs16;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100734 break;
735 case DataType::F32:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100736 _func = &logits_1d_norm_f32;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100737 break;
Pablo Tellob49a7152017-07-11 16:31:35 +0100738 case DataType::F16:
739#ifdef ARM_COMPUTE_ENABLE_FP16
740 _func = &logits_1d_norm_f16;
741 break;
742#endif /* ARM_COMPUTE_ENABLE_FP16 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100743 default:
744 ARM_COMPUTE_ERROR("Unsupported data type.");
Pablo Tellob49a7152017-07-11 16:31:35 +0100745 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100746 }
747
748 Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
749
750 AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
751 AccessWindowStatic sum_access(sum->info(), 0, 0, 1, sum->info()->dimension(1));
752 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
753
754 update_window_and_padding(win, input_access, sum_access, output_access);
755
756 output_access.set_valid_region(win, input->info()->valid_region());
757
758 INEKernel::configure(win);
759}
760
761void NELogits1DNormKernel::run(const Window &window)
762{
763 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
764 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
765 ARM_COMPUTE_ERROR_ON(_func == nullptr);
766
767 (*_func)(_input, _sum, _output, window);
768}