/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/NEMath.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"

#include <algorithm>
#include <arm_neon.h>
#include <cfloat>
#include <limits>

using namespace arm_compute;
43
44namespace
45{
46void logits_1d_max_f32(const ITensor *in, ITensor *out, const Window &window)
47{
48 Window in_slice = window.first_slice_window_1D();
49
50 Window window_max(window);
51 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
52 Window max_slice = window_max.first_slice_window_1D();
53
54 do
55 {
56 Iterator input(in, in_slice);
57 Iterator output(out, max_slice);
58
59 float32x4_t vec_max = vdupq_n_f32(-FLT_MAX);
60
61 execute_window_loop(in_slice, [&](const Coordinates & id)
62 {
63 const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
64 const float32x4_t current_value = vld1q_f32(in_ptr);
65 vec_max = vmaxq_f32(vec_max, current_value);
66 },
67 input);
68
69 float32x2_t carry_max = vpmax_f32(vget_high_f32(vec_max), vget_low_f32(vec_max));
70 carry_max = vpmax_f32(carry_max, carry_max);
71
72 *(reinterpret_cast<float *>(output.ptr())) = vget_lane_f32(carry_max, 0);
73 }
74 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
75}
76
77void logits_1d_max_qs8(const ITensor *in, ITensor *out, const Window &window)
78{
79 Window in_slice = window.first_slice_window_1D();
80
81 Window window_max(window);
82 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
83 Window max_slice = window_max.first_slice_window_1D();
84
85 do
86 {
87 Iterator input(in, in_slice);
88 Iterator output(out, max_slice);
89
90 qint8x16_t vec_max = vdupq_n_s8(-1);
91
92 execute_window_loop(in_slice, [&](const Coordinates & id)
93 {
94 const auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
95 const qint8x16_t current_value = vld1q_qs8(in_ptr);
96 vec_max = vmaxq_qs8(vec_max, current_value);
97 },
98 input);
99
100 qint8x8_t carry_max = vpmax_qs8(vget_high_s8(vec_max), vget_low_s8(vec_max));
101 carry_max = vpmax_qs8(carry_max, carry_max);
102 carry_max = vpmax_qs8(carry_max, carry_max);
103 carry_max = vpmax_qs8(carry_max, carry_max);
104
105 *(reinterpret_cast<int8_t *>(output.ptr())) = vget_lane_s8(carry_max, 0);
106 }
107 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
108}
109} // namespace
110
111NELogits1DMaxKernel::NELogits1DMaxKernel()
112 : _func(nullptr), _border_size()
113{
114}
115
116BorderSize NELogits1DMaxKernel::border_size() const
117{
118 return _border_size;
119}
120
121void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
122{
123 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32, DataType::QS8);
124 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32, DataType::QS8);
125 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
126
127 const int input_width = input->info()->valid_region().shape.x();
128 unsigned int num_elems_processed_per_iteration = 0;
129
130 switch(input->info()->data_type())
131 {
132 case DataType::QS8:
133 _func = &logits_1d_max_qs8;
134 num_elems_processed_per_iteration = 16;
135 break;
136 case DataType::F32:
137 num_elems_processed_per_iteration = 4;
138 _func = &logits_1d_max_f32;
139 break;
140 default:
141 ARM_COMPUTE_ERROR("Unsupported data type.");
142 }
143
144 _input = input;
145 _output = output;
146 _border_size = BorderSize(0, input_width % num_elems_processed_per_iteration, 0, 0);
147
148 // Configure kernel window
149 constexpr unsigned int num_elems_written_per_row = 1;
150
151 Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
152 AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
153 AccessWindowHorizontal output_access(output->info(), 0, num_elems_written_per_row, 1.f / input_width);
154
155 update_window_and_padding(win, input_access, output_access);
156
157 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
158
159 INEKernel::configure(win);
160}
161
162void NELogits1DMaxKernel::run(const Window &window)
163{
164 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
165 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
166 ARM_COMPUTE_ERROR_ON(_func == nullptr);
167
168 (*_func)(_input, _output, window);
169}
170
171namespace
172{
173void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
174{
175 Window window_max(window);
176 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
177
178 Window max_slice = window_max.first_slice_window_1D();
179 Window in_slice = window.first_slice_window_1D();
180
181 constexpr int step = 4;
182 const int long_steps = in->info()->valid_region().shape.x() / step;
183 const int small_steps = in->info()->valid_region().shape.x() % step;
184
185 do
186 {
187 Iterator input(in, in_slice);
188 Iterator exp(out, in_slice);
189 Iterator _max(max, max_slice);
190 Iterator _sum(sum, max_slice);
191
192 // Get pointers
193 auto in_ptr = reinterpret_cast<const float *>(input.ptr());
194 auto exp_ptr = reinterpret_cast<float *>(exp.ptr());
195
196 // Init sum to zero
197 float32x4_t vec_sum_value = vdupq_n_f32(0.0f);
198
199 // Get max value
200 const auto max_ptr = reinterpret_cast<const float *>(_max.ptr());
201 const float32x4_t vec_max = vdupq_n_f32(*max_ptr);
202
203 // Run neon loop
204 for(int i = 0; i < long_steps; ++i)
205 {
206 float32x4_t vec_elements = vld1q_f32(in_ptr);
207 vec_elements = vsubq_f32(vec_elements, vec_max);
208 vec_elements = vexpq_f32(vec_elements);
209
210 vst1q_f32(exp_ptr, vec_elements);
211 vec_sum_value = vaddq_f32(vec_elements, vec_sum_value);
212
213 in_ptr += step;
214 exp_ptr += step;
215 }
216
217 // Reduce sum
218 float32x2_t carry_addition = vpadd_f32(vget_high_f32(vec_sum_value), vget_low_f32(vec_sum_value));
219 carry_addition = vpadd_f32(carry_addition, carry_addition);
220 float sum = vget_lane_f32(carry_addition, 0);
221
222 // Run remaining elements
223 for(int i = 0; i < small_steps; ++i)
224 {
225 float element = std::exp(in_ptr[i] - *max_ptr);
226 exp_ptr[i] = element;
227 sum += element;
228 }
229
230 *(reinterpret_cast<float *>(_sum.ptr())) = sum;
231 }
232 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
233}
234void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
235{
236 Window window_max(window);
237 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
238
239 Window max_slice = window_max.first_slice_window_1D();
240 Window in_slice = window.first_slice_window_1D();
241
242 constexpr int step = 8;
243 const int long_steps = in->info()->valid_region().shape.x() / step;
244 const int small_steps = in->info()->valid_region().shape.x() % step;
245 const int fixed_point_position = in->info()->fixed_point_position();
246
247 do
248 {
249 Iterator input(in, in_slice);
250 Iterator exp(out, in_slice);
251 Iterator _max(max, max_slice);
252 Iterator _sum(sum, max_slice);
253
254 // Get pointers
255 auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
256 auto exp_ptr = reinterpret_cast<qint8_t *>(exp.ptr());
257
258 // Init sum to zero
259 qint16x8_t vec_sum_value = vdupq_n_qs16(0);
260
261 // Get max value
262 const auto max_ptr = reinterpret_cast<const qint8_t *>(_max.ptr());
263 const qint8x8_t vec_max = vdup_n_qs8(*max_ptr);
264
265 // Run neon loop
266 for(int i = 0; i < long_steps; ++i)
267 {
268 qint8x8_t vec_elements = vld1_qs8(in_ptr);
269 vec_elements = vqsub_qs8(vec_elements, vec_max);
270 vec_elements = vqexp_qs8(vec_elements, fixed_point_position);
271
272 vst1_qs8(exp_ptr, vec_elements);
273 vec_sum_value = vqaddq_qs16(vec_sum_value, vmovl_s8(vec_elements));
274
275 in_ptr += step;
276 exp_ptr += step;
277 }
278 // Reduce sum
279 const qint16x4_t sum_red = vqadd_qs16(vget_low_s16(vec_sum_value), vget_high_s16(vec_sum_value));
280 const qint16_t sum0 = sqadd_qs16(vget_lane_s16(sum_red, 0), vget_lane_s16(sum_red, 1));
281 const qint16_t sum1 = sqadd_qs16(vget_lane_s16(sum_red, 2), vget_lane_s16(sum_red, 3));
282 qint16_t sum = sqadd_qs16(sum0, sum1);
283
284 // Run remaining elements
285 for(int i = 0; i < small_steps; ++i)
286 {
287 qint8_t element = sqexp_qs8(sqsub_qs8(in_ptr[i], *max_ptr), fixed_point_position);
288 exp_ptr[i] = element;
289 sum = sqadd_qs16(sum, element);
290 }
291
292 *(reinterpret_cast<qint8_t *>(_sum.ptr())) = sqmovn_qs16(sum);
293 }
294 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
295}
296} //namespace
297
298NELogits1DShiftExpSumKernel::NELogits1DShiftExpSumKernel()
299 : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr)
300{
301}
302
303void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum)
304{
305 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32, DataType::QS8);
306 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(max, 1, DataType::F32, DataType::QS8);
307 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32, DataType::QS8);
308 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, max, output);
309 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, max, output);
310 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(max, sum);
311
312 unsigned int num_elems_processed_per_iteration = input->info()->valid_region().shape.x();
313
314 switch(input->info()->data_type())
315 {
316 case DataType::QS8:
317 _func = &logits_1d_shift_exp_sum_qs8;
318 break;
319 case DataType::F32:
320 _func = &logits_1d_shift_exp_sum_f32;
321 break;
322 default:
323 ARM_COMPUTE_ERROR("Unsupported data type.");
324 }
325
326 _input = input;
327 _max = max;
328 _output = output;
329 _sum = sum;
330
331 // Configure kernel window
332 Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
333 AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
334 AccessWindowHorizontal max_access(max->info(), 0, 1);
335 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
336 AccessWindowHorizontal sum_access(sum->info(), 0, 1);
337
338 update_window_and_padding(win, input_access, max_access, output_access, sum_access);
339
340 output_access.set_valid_region(win, input->info()->valid_region());
341 sum_access.set_valid_region(win, ValidRegion(Coordinates(), sum->info()->tensor_shape()));
342
343 INEKernel::configure(win);
344}
345
346void NELogits1DShiftExpSumKernel::run(const Window &window)
347{
348 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
349 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
350 ARM_COMPUTE_ERROR_ON(_func == nullptr);
351
352 (*_func)(_input, _max, _output, _sum, window);
353}
354
355namespace
356{
357void logits_1d_norm_f32(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
358{
359 Window window_sum(window);
360 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
361 Window sum_slice = window_sum.first_slice_window_1D();
362 Window in_slice = window.first_slice_window_1D();
363
364 do
365 {
366 Iterator input(in, in_slice);
367 Iterator _sum(sum, sum_slice);
368 Iterator output(out, in_slice);
369
370 const float sum_value = *reinterpret_cast<const float *>(_sum.ptr());
371 const float32x4_t vec_sum_inversed = vdupq_n_f32(1.0f / sum_value);
372
373 execute_window_loop(in_slice, [&](const Coordinates & id)
374 {
375 const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
376 const auto out_ptr = reinterpret_cast<float *>(output.ptr());
377
378 const float32x4_t vec_in = vld1q_f32(in_ptr);
379 const float32x4_t normalized_value = vmulq_f32(vec_in, vec_sum_inversed);
380
381 vst1q_f32(out_ptr, normalized_value);
382 },
383 input, output);
384 }
385 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
386}
387void logits_1d_norm_qs8(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
388{
389 Window window_sum(window);
390 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
391 Window sum_slice = window_sum.first_slice_window_1D();
392 Window in_slice = window.first_slice_window_1D();
393
394 const int fixed_point_position = in->info()->fixed_point_position();
395
396 do
397 {
398 Iterator input(in, in_slice);
399 Iterator _sum(sum, sum_slice);
400 Iterator output(out, in_slice);
401
402 const int8_t sum_value = *reinterpret_cast<const qint8_t *>(_sum.ptr());
403 const qint8x16_t vec_sum_inversed = vqrecipq_qs8(vdupq_n_qs8(sum_value), fixed_point_position);
404
405 execute_window_loop(in_slice, [&](const Coordinates & id)
406 {
407 const auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
408 const auto out_ptr = reinterpret_cast<qint8_t *>(output.ptr());
409
410 const qint8x16_t vec_in = vld1q_qs8(in_ptr);
411 const qint8x16_t normalized_value = vqmulq_qs8(vec_in, vec_sum_inversed, fixed_point_position);
412
413 vst1q_qs8(out_ptr, normalized_value);
414 },
415 input, output);
416 }
417 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
418}
419} // namespace
420
421NELogits1DNormKernel::NELogits1DNormKernel()
422 : _func(nullptr), _input(nullptr), _sum(nullptr), _output(nullptr)
423{
424}
425
426void NELogits1DNormKernel::configure(const ITensor *input, const ITensor *sum, ITensor *output)
427{
428 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32, DataType::QS8);
429 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, sum);
430 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output, sum);
431 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
432
433 _input = input;
434 _sum = sum;
435 _output = output;
436
437 // Configure kernel window
438 unsigned int num_elems_processed_per_iteration = 0;
439
440 switch(input->info()->data_type())
441 {
442 case DataType::QS8:
443 _func = &logits_1d_norm_qs8;
444 num_elems_processed_per_iteration = 16;
445 break;
446 case DataType::F32:
447 num_elems_processed_per_iteration = 4;
448 _func = &logits_1d_norm_f32;
449 break;
450 default:
451 ARM_COMPUTE_ERROR("Unsupported data type.");
452 }
453
454 Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
455
456 AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
457 AccessWindowStatic sum_access(sum->info(), 0, 0, 1, sum->info()->dimension(1));
458 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
459
460 update_window_and_padding(win, input_access, sum_access, output_access);
461
462 output_access.set_valid_region(win, input->info()->valid_region());
463
464 INEKernel::configure(win);
465}
466
467void NELogits1DNormKernel::run(const Window &window)
468{
469 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
470 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
471 ARM_COMPUTE_ERROR_ON(_func == nullptr);
472
473 (*_func)(_input, _sum, _output, window);
474}