blob: 79fcba1dfbd9fc66c768b6cf331e44ef2a949de7 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"
25
26#include "arm_compute/core/AccessWindowStatic.h"
27#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/Helpers.h"
30#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/NEFixedPoint.h"
32#include "arm_compute/core/NEON/NEMath.h"
33#include "arm_compute/core/TensorInfo.h"
34#include "arm_compute/core/Utils.h"
35#include "arm_compute/core/Validate.h"
36#include "arm_compute/core/Window.h"
37
38#include <algorithm>
39#include <arm_neon.h>
40#include <cfloat>
41
42using namespace arm_compute;
43
44namespace
45{
Georgios Pinitas9247c922017-06-28 18:29:47 +010046void logits_1d_max_qs8(const ITensor *in, ITensor *out, const Window &window)
47{
48 Window in_slice = window.first_slice_window_1D();
49
50 Window window_max(window);
51 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
52 Window max_slice = window_max.first_slice_window_1D();
53
54 do
55 {
56 Iterator input(in, in_slice);
57 Iterator output(out, max_slice);
58
59 qint8x16_t vec_max = vdupq_n_s8(std::numeric_limits<qint8_t>::lowest());
60
61 execute_window_loop(in_slice, [&](const Coordinates & id)
62 {
63 const auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
64 const qint8x16_t current_value = vld1q_qs8(in_ptr);
65 vec_max = vmaxq_qs8(vec_max, current_value);
66 },
67 input);
68
69 qint8x8_t carry_max = vpmax_qs8(vget_high_s8(vec_max), vget_low_s8(vec_max));
70 carry_max = vpmax_qs8(carry_max, carry_max);
71 carry_max = vpmax_qs8(carry_max, carry_max);
72 carry_max = vpmax_qs8(carry_max, carry_max);
73
74 *(reinterpret_cast<qint8_t *>(output.ptr())) = vget_lane_s8(carry_max, 0);
75 }
76 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
77}
78void logits_1d_max_qs16(const ITensor *in, ITensor *out, const Window &window)
79{
80 Window in_slice = window.first_slice_window_1D();
81
82 Window window_max(window);
83 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
84 Window max_slice = window_max.first_slice_window_1D();
85
86 do
87 {
88 Iterator input(in, in_slice);
89 Iterator output(out, max_slice);
90
91 qint16x8_t vec_max = vdupq_n_qs16(std::numeric_limits<qint16_t>::lowest());
92
93 execute_window_loop(in_slice, [&](const Coordinates & id)
94 {
95 const auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
96 const qint16x8_t current_value = vld1q_qs16(in_ptr);
97 vec_max = vmaxq_qs16(vec_max, current_value);
98 },
99 input);
100
101 qint16x4_t carry_max = vpmax_qs16(vget_high_qs16(vec_max), vget_low_qs16(vec_max));
102 carry_max = vpmax_qs16(carry_max, carry_max);
103 carry_max = vpmax_qs16(carry_max, carry_max);
104
105 *(reinterpret_cast<qint16_t *>(output.ptr())) = vget_lane_s16(carry_max, 0);
106 }
107 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
108}
Pablo Tellob49a7152017-07-11 16:31:35 +0100109
110#ifdef ARM_COMPUTE_ENABLE_FP16
111void logits_1d_max_f16(const ITensor *in, ITensor *out, const Window &window)
112{
113 Window in_slice = window.first_slice_window_1D();
114
115 Window window_max(window);
116 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
117 Window max_slice = window_max.first_slice_window_1D();
118
119 do
120 {
121 Iterator input(in, in_slice);
122 Iterator output(out, max_slice);
123
124 float16x8_t vec_max = vdupq_n_f16(std::numeric_limits<float16_t>::lowest());
125
126 execute_window_loop(in_slice, [&](const Coordinates & id)
127 {
128 const auto in_ptr = reinterpret_cast<const float16_t *>(input.ptr());
129 const float16x8_t current_value = vld1q_f16(in_ptr);
130 vec_max = vmaxq_f16(vec_max, current_value);
131 },
132 input);
133
134 float16x4_t carry_max = vpmax_f16(vget_high_f16(vec_max), vget_low_f16(vec_max));
135 carry_max = vpmax_f16(carry_max, carry_max);
136 carry_max = vpmax_f16(carry_max, carry_max);
137
138 *(reinterpret_cast<float16_t *>(output.ptr())) = vget_lane_f16(carry_max, 0);
139 }
140 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
141}
142#endif /* ARM_COMPUTE_ENABLE_FP16 */
143
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100144void logits_1d_max_f32(const ITensor *in, ITensor *out, const Window &window)
145{
146 Window in_slice = window.first_slice_window_1D();
147
148 Window window_max(window);
149 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
150 Window max_slice = window_max.first_slice_window_1D();
151
152 do
153 {
154 Iterator input(in, in_slice);
155 Iterator output(out, max_slice);
156
157 float32x4_t vec_max = vdupq_n_f32(-FLT_MAX);
158
159 execute_window_loop(in_slice, [&](const Coordinates & id)
160 {
161 const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
162 const float32x4_t current_value = vld1q_f32(in_ptr);
163 vec_max = vmaxq_f32(vec_max, current_value);
164 },
165 input);
166
167 float32x2_t carry_max = vpmax_f32(vget_high_f32(vec_max), vget_low_f32(vec_max));
168 carry_max = vpmax_f32(carry_max, carry_max);
169
170 *(reinterpret_cast<float *>(output.ptr())) = vget_lane_f32(carry_max, 0);
171 }
172 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
173}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100174} // namespace
175
176NELogits1DMaxKernel::NELogits1DMaxKernel()
177 : _func(nullptr), _border_size()
178{
179}
180
181BorderSize NELogits1DMaxKernel::border_size() const
182{
183 return _border_size;
184}
185
186void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
187{
Pablo Tellob49a7152017-07-11 16:31:35 +0100188 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100189 ARM_COMPUTE_ERROR_ON_NULLPTR(output);
190
191 // Softmax across the x dimension
192 TensorShape output_shape{ input->info()->tensor_shape() };
193 output_shape.set(0, 1);
194
195 // Output auto initialization if not yet initialized
196 auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
197
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100198 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100199 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
200 ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100201
202 const int input_width = input->info()->valid_region().shape.x();
Georgios Pinitas9247c922017-06-28 18:29:47 +0100203 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100204
205 switch(input->info()->data_type())
206 {
207 case DataType::QS8:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100208 _func = &logits_1d_max_qs8;
209 break;
210 case DataType::QS16:
211 _func = &logits_1d_max_qs16;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100212 break;
213 case DataType::F32:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100214 _func = &logits_1d_max_f32;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100215 break;
Pablo Tellob49a7152017-07-11 16:31:35 +0100216 case DataType::F16:
217#ifdef ARM_COMPUTE_ENABLE_FP16
218 _func = &logits_1d_max_f16;
219 break;
220#endif /* ARM_COMPUTE_ENABLE_FP16 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100221 default:
222 ARM_COMPUTE_ERROR("Unsupported data type.");
223 }
224
225 _input = input;
226 _output = output;
227 _border_size = BorderSize(0, input_width % num_elems_processed_per_iteration, 0, 0);
228
229 // Configure kernel window
230 constexpr unsigned int num_elems_written_per_row = 1;
231
232 Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
233 AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
234 AccessWindowHorizontal output_access(output->info(), 0, num_elems_written_per_row, 1.f / input_width);
235
236 update_window_and_padding(win, input_access, output_access);
237
238 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
239
240 INEKernel::configure(win);
241}
242
243void NELogits1DMaxKernel::run(const Window &window)
244{
245 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
246 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
247 ARM_COMPUTE_ERROR_ON(_func == nullptr);
248
249 (*_func)(_input, _output, window);
250}
251
252namespace
253{
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100254void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
255{
256 Window window_max(window);
257 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
258
259 Window max_slice = window_max.first_slice_window_1D();
260 Window in_slice = window.first_slice_window_1D();
261
262 constexpr int step = 8;
263 const int long_steps = in->info()->valid_region().shape.x() / step;
264 const int small_steps = in->info()->valid_region().shape.x() % step;
265 const int fixed_point_position = in->info()->fixed_point_position();
266
267 do
268 {
269 Iterator input(in, in_slice);
270 Iterator exp(out, in_slice);
271 Iterator _max(max, max_slice);
272 Iterator _sum(sum, max_slice);
273
274 // Get pointers
275 auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
276 auto exp_ptr = reinterpret_cast<qint8_t *>(exp.ptr());
277
278 // Init sum to zero
279 qint16x8_t vec_sum_value = vdupq_n_qs16(0);
280
281 // Get max value
282 const auto max_ptr = reinterpret_cast<const qint8_t *>(_max.ptr());
283 const qint8x8_t vec_max = vdup_n_qs8(*max_ptr);
284
285 // Run neon loop
286 for(int i = 0; i < long_steps; ++i)
287 {
288 qint8x8_t vec_elements = vld1_qs8(in_ptr);
289 vec_elements = vqsub_qs8(vec_elements, vec_max);
290 vec_elements = vqexp_qs8(vec_elements, fixed_point_position);
291
292 vst1_qs8(exp_ptr, vec_elements);
293 vec_sum_value = vqaddq_qs16(vec_sum_value, vmovl_s8(vec_elements));
294
295 in_ptr += step;
296 exp_ptr += step;
297 }
298 // Reduce sum
299 const qint16x4_t sum_red = vqadd_qs16(vget_low_s16(vec_sum_value), vget_high_s16(vec_sum_value));
300 const qint16_t sum0 = sqadd_qs16(vget_lane_s16(sum_red, 0), vget_lane_s16(sum_red, 1));
301 const qint16_t sum1 = sqadd_qs16(vget_lane_s16(sum_red, 2), vget_lane_s16(sum_red, 3));
302 qint16_t sum = sqadd_qs16(sum0, sum1);
303
304 // Run remaining elements
305 for(int i = 0; i < small_steps; ++i)
306 {
307 qint8_t element = sqexp_qs8(sqsub_qs8(in_ptr[i], *max_ptr), fixed_point_position);
308 exp_ptr[i] = element;
309 sum = sqadd_qs16(sum, element);
310 }
311
312 *(reinterpret_cast<qint8_t *>(_sum.ptr())) = sqmovn_qs16(sum);
313 }
314 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
315}
Georgios Pinitas9247c922017-06-28 18:29:47 +0100316void logits_1d_shift_exp_sum_qs16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
317{
318 Window window_max(window);
319 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
320
321 Window max_slice = window_max.first_slice_window_1D();
322 Window in_slice = window.first_slice_window_1D();
323
324 constexpr int step = 4;
325 const int long_steps = in->info()->valid_region().shape.x() / step;
326 const int small_steps = in->info()->valid_region().shape.x() % step;
327 const int fixed_point_position = in->info()->fixed_point_position();
328
329 do
330 {
331 Iterator input(in, in_slice);
332 Iterator exp(out, in_slice);
333 Iterator _max(max, max_slice);
334 Iterator _sum(sum, max_slice);
335
336 // Get pointers
337 auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
338 auto exp_ptr = reinterpret_cast<qint16_t *>(exp.ptr());
339
340 // Init sum to zero
341 qint32x4_t vec_sum_value = vdupq_n_qs32(0);
342
343 // Get max value
344 const auto max_ptr = reinterpret_cast<const qint16_t *>(_max.ptr());
345 const qint16x4_t vec_max = vdup_n_qs16(*max_ptr);
346
347 // Run neon loop
348 for(int i = 0; i < long_steps; ++i)
349 {
350 qint16x4_t vec_elements = vld1_qs16(in_ptr);
351 vec_elements = vqsub_qs16(vec_elements, vec_max);
352 vec_elements = vqexp_qs16(vec_elements, fixed_point_position);
353
354 vst1_qs16(exp_ptr, vec_elements);
355 vec_sum_value = vqaddq_qs32(vec_sum_value, vmovl_s16(vec_elements));
356
357 in_ptr += step;
358 exp_ptr += step;
359 }
360 // Reduce sum
361 qint32x2_t carry_addition = vqadd_qs32(vget_high_s32(vec_sum_value), vget_low_s32(vec_sum_value));
362 qint32_t sum = vget_lane_s32(carry_addition, 0) + vget_lane_s32(carry_addition, 1);
363
364 // Run remaining elements
365 for(int i = 0; i < small_steps; ++i)
366 {
367 qint16_t element = sqexp_qs16(sqsub_qs16(in_ptr[i], *max_ptr), fixed_point_position);
368 exp_ptr[i] = element;
369 sum = sqadd_qs32(sum, element);
370 }
371
372 *(reinterpret_cast<qint16_t *>(_sum.ptr())) = sqmovn_qs32(sum);
373 }
374 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
375}
Pablo Tellob49a7152017-07-11 16:31:35 +0100376
377#ifdef ARM_COMPUTE_ENABLE_FP16
378void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
379{
380 Window window_max(window);
381 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
382
383 Window max_slice = window_max.first_slice_window_1D();
384 Window in_slice = window.first_slice_window_1D();
385
386 constexpr int step = 8;
387 const int long_steps = in->info()->valid_region().shape.x() / step;
388 const int small_steps = in->info()->valid_region().shape.x() % step;
389
390 do
391 {
392 Iterator input(in, in_slice);
393 Iterator exp(out, in_slice);
394 Iterator _max(max, max_slice);
395 Iterator _sum(sum, max_slice);
396
397 // Get pointers
398 auto in_ptr = reinterpret_cast<const float16_t *>(input.ptr());
399 auto exp_ptr = reinterpret_cast<float16_t *>(exp.ptr());
400
401 // Init sum to zero
402 float16x8_t vec_sum_value = vdupq_n_f16(0);
403
404 // Get max value
405 const auto max_ptr = reinterpret_cast<const float16_t *>(_max.ptr());
406 const float16x8_t vec_max = vdupq_n_f16(*max_ptr);
407
408 // Run neon loop
409 for(int i = 0; i < long_steps; ++i)
410 {
411 float16x8_t vec_elements = vld1q_f16(in_ptr);
412 vec_elements = vsubq_f16(vec_elements, vec_max);
413 vec_elements = vexpq_f16(vec_elements);
414
415 vst1q_f16(exp_ptr, vec_elements);
416 vec_sum_value = vaddq_f16(vec_sum_value, vec_elements);
417
418 in_ptr += step;
419 exp_ptr += step;
420 }
421 // Reduce sum
422 const float16x4_t sum_red = vadd_f16(vget_low_f16(vec_sum_value), vget_high_f16(vec_sum_value));
423 const float16x4_t carry_addition = vpadd_f16(sum_red, sum_red);
424 float16_t sum = vget_lane_f16(carry_addition, 0) + vget_lane_f16(carry_addition, 1);
425
426 // Run remaining elements
427 for(int i = 0; i < small_steps; ++i)
428 {
429 const float16_t element = std::exp(static_cast<float>(in_ptr[i] - *max_ptr));
430 exp_ptr[i] = element;
431 sum += element;
432 }
433 *(reinterpret_cast<float16_t *>(_sum.ptr())) = sum;
434 }
435 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
436}
437#endif /* ARM_COMPUTE_ENABLE_FP16 */
438
Georgios Pinitas9247c922017-06-28 18:29:47 +0100439void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
440{
441 Window window_max(window);
442 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
443
444 Window max_slice = window_max.first_slice_window_1D();
445 Window in_slice = window.first_slice_window_1D();
446
447 constexpr int step = 4;
448 const int long_steps = in->info()->valid_region().shape.x() / step;
449 const int small_steps = in->info()->valid_region().shape.x() % step;
450
451 do
452 {
453 Iterator input(in, in_slice);
454 Iterator exp(out, in_slice);
455 Iterator _max(max, max_slice);
456 Iterator _sum(sum, max_slice);
457
458 // Get pointers
459 auto in_ptr = reinterpret_cast<const float *>(input.ptr());
460 auto exp_ptr = reinterpret_cast<float *>(exp.ptr());
461
462 // Init sum to zero
463 float32x4_t vec_sum_value = vdupq_n_f32(0.0f);
464
465 // Get max value
466 const auto max_ptr = reinterpret_cast<const float *>(_max.ptr());
467 const float32x4_t vec_max = vdupq_n_f32(*max_ptr);
468
469 // Run neon loop
470 for(int i = 0; i < long_steps; ++i)
471 {
472 float32x4_t vec_elements = vld1q_f32(in_ptr);
473 vec_elements = vsubq_f32(vec_elements, vec_max);
474 vec_elements = vexpq_f32(vec_elements);
475
476 vst1q_f32(exp_ptr, vec_elements);
477 vec_sum_value = vaddq_f32(vec_elements, vec_sum_value);
478
479 in_ptr += step;
480 exp_ptr += step;
481 }
482
483 // Reduce sum
484 float32x2_t carry_addition = vpadd_f32(vget_high_f32(vec_sum_value), vget_low_f32(vec_sum_value));
485 carry_addition = vpadd_f32(carry_addition, carry_addition);
486 float sum = vget_lane_f32(carry_addition, 0);
487
488 // Run remaining elements
489 for(int i = 0; i < small_steps; ++i)
490 {
491 float element = std::exp(in_ptr[i] - *max_ptr);
492 exp_ptr[i] = element;
493 sum += element;
494 }
495
496 *(reinterpret_cast<float *>(_sum.ptr())) = sum;
497 }
498 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
499}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100500} //namespace
501
502NELogits1DShiftExpSumKernel::NELogits1DShiftExpSumKernel()
503 : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr)
504{
505}
506
507void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum)
508{
Pablo Tellob49a7152017-07-11 16:31:35 +0100509 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100510 ARM_COMPUTE_ERROR_ON_NULLPTR(max, sum, output);
511
512 // Output auto initialization if not yet initialized
513 auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
514 auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
515
516 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, max, sum);
517 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output, max, sum);
518 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100519 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(max, sum);
520
521 unsigned int num_elems_processed_per_iteration = input->info()->valid_region().shape.x();
522
523 switch(input->info()->data_type())
524 {
525 case DataType::QS8:
526 _func = &logits_1d_shift_exp_sum_qs8;
527 break;
Georgios Pinitas9247c922017-06-28 18:29:47 +0100528 case DataType::QS16:
529 _func = &logits_1d_shift_exp_sum_qs16;
530 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100531 case DataType::F32:
532 _func = &logits_1d_shift_exp_sum_f32;
533 break;
Pablo Tellob49a7152017-07-11 16:31:35 +0100534 case DataType::F16:
535#ifdef ARM_COMPUTE_ENABLE_FP16
536 _func = &logits_1d_shift_exp_sum_f16;
537 break;
538#endif /* ARM_COMPUTE_ENABLE_FP16 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100539 default:
540 ARM_COMPUTE_ERROR("Unsupported data type.");
Pablo Tellob49a7152017-07-11 16:31:35 +0100541 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100542 }
543
544 _input = input;
545 _max = max;
546 _output = output;
547 _sum = sum;
548
549 // Configure kernel window
550 Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
551 AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
552 AccessWindowHorizontal max_access(max->info(), 0, 1);
553 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
554 AccessWindowHorizontal sum_access(sum->info(), 0, 1);
555
556 update_window_and_padding(win, input_access, max_access, output_access, sum_access);
557
558 output_access.set_valid_region(win, input->info()->valid_region());
559 sum_access.set_valid_region(win, ValidRegion(Coordinates(), sum->info()->tensor_shape()));
560
561 INEKernel::configure(win);
562}
563
564void NELogits1DShiftExpSumKernel::run(const Window &window)
565{
566 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
567 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
568 ARM_COMPUTE_ERROR_ON(_func == nullptr);
569
570 (*_func)(_input, _max, _output, _sum, window);
571}
572
573namespace
574{
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100575void logits_1d_norm_qs8(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
576{
577 Window window_sum(window);
578 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
579 Window sum_slice = window_sum.first_slice_window_1D();
580 Window in_slice = window.first_slice_window_1D();
581
582 const int fixed_point_position = in->info()->fixed_point_position();
583
584 do
585 {
586 Iterator input(in, in_slice);
587 Iterator _sum(sum, sum_slice);
588 Iterator output(out, in_slice);
589
590 const int8_t sum_value = *reinterpret_cast<const qint8_t *>(_sum.ptr());
591 const qint8x16_t vec_sum_inversed = vqrecipq_qs8(vdupq_n_qs8(sum_value), fixed_point_position);
592
593 execute_window_loop(in_slice, [&](const Coordinates & id)
594 {
595 const auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
596 const auto out_ptr = reinterpret_cast<qint8_t *>(output.ptr());
597
598 const qint8x16_t vec_in = vld1q_qs8(in_ptr);
599 const qint8x16_t normalized_value = vqmulq_qs8(vec_in, vec_sum_inversed, fixed_point_position);
600
601 vst1q_qs8(out_ptr, normalized_value);
602 },
603 input, output);
604 }
605 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
606}
Georgios Pinitas9247c922017-06-28 18:29:47 +0100607void logits_1d_norm_qs16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
608{
609 Window window_sum(window);
610 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
611 Window sum_slice = window_sum.first_slice_window_1D();
612 Window in_slice = window.first_slice_window_1D();
613
614 const int fixed_point_position = in->info()->fixed_point_position();
615
616 do
617 {
618 Iterator input(in, in_slice);
619 Iterator _sum(sum, sum_slice);
620 Iterator output(out, in_slice);
621
622 const int16_t sum_value = *reinterpret_cast<const qint16_t *>(_sum.ptr());
623 const qint16x8_t vec_sum_inversed = vqrecipq_qs16(vdupq_n_qs16(sum_value), fixed_point_position);
624
625 execute_window_loop(in_slice, [&](const Coordinates & id)
626 {
627 const auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
628 const auto out_ptr = reinterpret_cast<qint16_t *>(output.ptr());
629
630 const qint16x8_t vec_in = vld1q_qs16(in_ptr);
631 const qint16x8_t normalized_value = vqmulq_qs16(vec_in, vec_sum_inversed, fixed_point_position);
632
633 vst1q_qs16(out_ptr, normalized_value);
634 },
635 input, output);
636 }
637 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
638}
Pablo Tellob49a7152017-07-11 16:31:35 +0100639#ifdef ARM_COMPUTE_ENABLE_FP16
640void logits_1d_norm_f16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
641{
642 Window window_sum(window);
643 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
644 Window sum_slice = window_sum.first_slice_window_1D();
645 Window in_slice = window.first_slice_window_1D();
646
647 do
648 {
649 Iterator input(in, in_slice);
650 Iterator _sum(sum, sum_slice);
651 Iterator output(out, in_slice);
652
653 const float16_t sum_value = *reinterpret_cast<const qint16_t *>(_sum.ptr());
654 const float16x8_t vec_sum_inversed = vdupq_n_f16(1.0f / sum_value);
655
656 execute_window_loop(in_slice, [&](const Coordinates & id)
657 {
658 const auto in_ptr = reinterpret_cast<const float16_t *>(input.ptr());
659 const auto out_ptr = reinterpret_cast<float16_t *>(output.ptr());
660
661 const float16x8_t vec_in = vld1q_f16(in_ptr);
662 const float16x8_t normalized_value = vmulq_f16(vec_in, vec_sum_inversed);
663
664 vst1q_f16(out_ptr, normalized_value);
665 },
666 input, output);
667 }
668 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
669}
670#endif /* ARM_COMPUTE_ENABLE_FP16 */
671
Georgios Pinitas9247c922017-06-28 18:29:47 +0100672void logits_1d_norm_f32(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
673{
674 Window window_sum(window);
675 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
676 Window sum_slice = window_sum.first_slice_window_1D();
677 Window in_slice = window.first_slice_window_1D();
678
679 do
680 {
681 Iterator input(in, in_slice);
682 Iterator _sum(sum, sum_slice);
683 Iterator output(out, in_slice);
684
685 const float sum_value = *reinterpret_cast<const float *>(_sum.ptr());
686 const float32x4_t vec_sum_inversed = vdupq_n_f32(1.0f / sum_value);
687
688 execute_window_loop(in_slice, [&](const Coordinates & id)
689 {
690 const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
691 const auto out_ptr = reinterpret_cast<float *>(output.ptr());
692
693 const float32x4_t vec_in = vld1q_f32(in_ptr);
694 const float32x4_t normalized_value = vmulq_f32(vec_in, vec_sum_inversed);
695
696 vst1q_f32(out_ptr, normalized_value);
697 },
698 input, output);
699 }
700 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
701}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100702} // namespace
703
704NELogits1DNormKernel::NELogits1DNormKernel()
705 : _func(nullptr), _input(nullptr), _sum(nullptr), _output(nullptr)
706{
707}
708
709void NELogits1DNormKernel::configure(const ITensor *input, const ITensor *sum, ITensor *output)
710{
Pablo Tellob49a7152017-07-11 16:31:35 +0100711 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100712 ARM_COMPUTE_ERROR_ON_NULLPTR(sum, output);
713
714 // Output auto initialization if not yet initialized
715 auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
716
717 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum, output);
718 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100719 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
720
721 _input = input;
722 _sum = sum;
723 _output = output;
724
725 // Configure kernel window
Georgios Pinitas9247c922017-06-28 18:29:47 +0100726 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100727
728 switch(input->info()->data_type())
729 {
730 case DataType::QS8:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100731 _func = &logits_1d_norm_qs8;
732 break;
733 case DataType::QS16:
734 _func = &logits_1d_norm_qs16;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100735 break;
736 case DataType::F32:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100737 _func = &logits_1d_norm_f32;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100738 break;
Pablo Tellob49a7152017-07-11 16:31:35 +0100739 case DataType::F16:
740#ifdef ARM_COMPUTE_ENABLE_FP16
741 _func = &logits_1d_norm_f16;
742 break;
743#endif /* ARM_COMPUTE_ENABLE_FP16 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100744 default:
745 ARM_COMPUTE_ERROR("Unsupported data type.");
Pablo Tellob49a7152017-07-11 16:31:35 +0100746 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100747 }
748
749 Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
750
751 AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
752 AccessWindowStatic sum_access(sum->info(), 0, 0, 1, sum->info()->dimension(1));
753 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
754
755 update_window_and_padding(win, input_access, sum_access, output_access);
756
757 output_access.set_valid_region(win, input->info()->valid_region());
758
759 INEKernel::configure(win);
760}
761
762void NELogits1DNormKernel::run(const Window &window)
763{
764 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
765 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
766 ARM_COMPUTE_ERROR_ON(_func == nullptr);
767
768 (*_func)(_input, _sum, _output, window);
769}