blob: b13fb0e87cebfd328e0f4a38098a670adf5473e9 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"
25
26#include "arm_compute/core/AccessWindowStatic.h"
27#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/ITensor.h"
30#include "arm_compute/core/NEON/NEFixedPoint.h"
31#include "arm_compute/core/NEON/NEMath.h"
32#include "arm_compute/core/TensorInfo.h"
33#include "arm_compute/core/Utils.h"
34#include "arm_compute/core/Validate.h"
35#include "arm_compute/core/Window.h"
36
37#include <algorithm>
38#include <arm_neon.h>
39#include <cfloat>
40
41using namespace arm_compute;
42
43namespace
44{
Michalis Spyrouafa5d812017-11-30 14:25:57 +000045Status validate_arguments_logits_1d_max(const ITensorInfo *input, const ITensorInfo *output)
46{
47 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
48 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
49
50 // Checks performed when output is configured
51 if(output->total_size() != 0)
52 {
53 // Softmax across the x dimension
54 TensorShape output_shape{ input->tensor_shape() };
55 output_shape.set(0, 1);
56
57 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
58 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
59 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
60 }
61
62 return Status{};
63}
64
65std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInfo *input, ITensorInfo *output)
66{
67 // Configure kernel window
68 constexpr unsigned int num_elems_written_per_row = 1;
69 const int input_width = input->valid_region().shape.x();
70
71 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
72 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
73 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
74 bool window_changed = false;
75
76 if(output->total_size() != 0)
77 {
78 AccessWindowHorizontal output_access(output, 0, num_elems_written_per_row, 1.f / input_width);
79 window_changed = update_window_and_padding(win, input_access, output_access);
80 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
81 }
82 else
83 {
84 window_changed = update_window_and_padding(win, input_access);
85 }
86
87 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
88 return std::make_pair(err, win);
89}
90
91Status validate_arguments_logits_1d_shift_exp_sum(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum, float beta)
92{
93 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, max, sum, output);
94 ARM_COMPUTE_RETURN_ERROR_ON((beta != 1.0f) && is_data_type_fixed_point(input->data_type()));
95 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
96
97 // Checks performed when output is configured
98 if(output->total_size() != 0)
99 {
100 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
101 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
102 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
103 }
104
105 // Checks performed when sum is configured
106 if(sum->total_size() != 0)
107 {
108 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, max, sum);
109 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(max, sum);
110 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, max, sum);
111 }
112
113 return Status{};
114}
115
116std::pair<Status, Window> validate_and_configure_window_logits_1d_shift_exp_sum(ITensorInfo *input, ITensorInfo *max, ITensorInfo *output, ITensorInfo *sum)
117{
118 unsigned int num_elems_processed_per_iteration = input->valid_region().shape.x();
119
120 // Configure kernel window
121 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
122 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
123 AccessWindowHorizontal max_access(max, 0, 1);
124 AccessWindowHorizontal sum_access(sum, 0, 1);
125 bool window_changed = false;
126
127 if(output->total_size() != 0)
128 {
129 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
130 window_changed = update_window_and_padding(win, input_access, max_access, output_access, sum_access);
131 output_access.set_valid_region(win, input->valid_region());
132 }
133 else
134 {
135 window_changed = update_window_and_padding(win, input_access, max_access, sum_access);
136 }
137
138 sum_access.set_valid_region(win, ValidRegion(Coordinates(), sum->tensor_shape()));
139
140 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
141 return std::make_pair(err, win);
142}
143
144Status validate_arguments_logits_1d_norm(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
145{
146 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, sum, output);
147 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::S32, DataType::F16, DataType::F32);
148 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum);
149 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum);
150
151 // Checks performed when output is configured
152 if(output->total_size() != 0)
153 {
154 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
155 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
156 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
157 }
158
159 return Status{};
160}
161
162std::pair<Status, Window> validate_and_configure_window_logits_1d_norm(ITensorInfo *input, ITensorInfo *sum, ITensorInfo *output)
163{
164 // Configure kernel window
165 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
166 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
167
168 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
169 AccessWindowStatic sum_access(sum, 0, 0, 1, sum->dimension(1));
170 bool window_changed = false;
171
172 if(output->total_size() != 0)
173 {
174 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
175
176 window_changed = update_window_and_padding(win, input_access, sum_access, output_access);
177
178 output_access.set_valid_region(win, input->valid_region());
179 }
180 else
181 {
182 window_changed = update_window_and_padding(win, input_access, sum_access);
183 }
184 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
185 return std::make_pair(err, win);
186}
187
Georgios Pinitas9247c922017-06-28 18:29:47 +0100188void logits_1d_max_qs8(const ITensor *in, ITensor *out, const Window &window)
189{
190 Window in_slice = window.first_slice_window_1D();
191
192 Window window_max(window);
193 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
194 Window max_slice = window_max.first_slice_window_1D();
195
196 do
197 {
198 Iterator input(in, in_slice);
199 Iterator output(out, max_slice);
200
201 qint8x16_t vec_max = vdupq_n_s8(std::numeric_limits<qint8_t>::lowest());
202
203 execute_window_loop(in_slice, [&](const Coordinates & id)
204 {
205 const auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
206 const qint8x16_t current_value = vld1q_qs8(in_ptr);
207 vec_max = vmaxq_qs8(vec_max, current_value);
208 },
209 input);
210
211 qint8x8_t carry_max = vpmax_qs8(vget_high_s8(vec_max), vget_low_s8(vec_max));
212 carry_max = vpmax_qs8(carry_max, carry_max);
213 carry_max = vpmax_qs8(carry_max, carry_max);
214 carry_max = vpmax_qs8(carry_max, carry_max);
215
216 *(reinterpret_cast<qint8_t *>(output.ptr())) = vget_lane_s8(carry_max, 0);
217 }
218 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
219}
220void logits_1d_max_qs16(const ITensor *in, ITensor *out, const Window &window)
221{
222 Window in_slice = window.first_slice_window_1D();
223
224 Window window_max(window);
225 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
226 Window max_slice = window_max.first_slice_window_1D();
227
228 do
229 {
230 Iterator input(in, in_slice);
231 Iterator output(out, max_slice);
232
233 qint16x8_t vec_max = vdupq_n_qs16(std::numeric_limits<qint16_t>::lowest());
234
235 execute_window_loop(in_slice, [&](const Coordinates & id)
236 {
237 const auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
238 const qint16x8_t current_value = vld1q_qs16(in_ptr);
239 vec_max = vmaxq_qs16(vec_max, current_value);
240 },
241 input);
242
243 qint16x4_t carry_max = vpmax_qs16(vget_high_qs16(vec_max), vget_low_qs16(vec_max));
244 carry_max = vpmax_qs16(carry_max, carry_max);
245 carry_max = vpmax_qs16(carry_max, carry_max);
246
247 *(reinterpret_cast<qint16_t *>(output.ptr())) = vget_lane_s16(carry_max, 0);
248 }
249 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
250}
Pablo Tellob49a7152017-07-11 16:31:35 +0100251
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000252#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tellob49a7152017-07-11 16:31:35 +0100253void logits_1d_max_f16(const ITensor *in, ITensor *out, const Window &window)
254{
255 Window in_slice = window.first_slice_window_1D();
256
257 Window window_max(window);
258 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
259 Window max_slice = window_max.first_slice_window_1D();
260
261 do
262 {
263 Iterator input(in, in_slice);
264 Iterator output(out, max_slice);
265
266 float16x8_t vec_max = vdupq_n_f16(std::numeric_limits<float16_t>::lowest());
267
268 execute_window_loop(in_slice, [&](const Coordinates & id)
269 {
270 const auto in_ptr = reinterpret_cast<const float16_t *>(input.ptr());
271 const float16x8_t current_value = vld1q_f16(in_ptr);
272 vec_max = vmaxq_f16(vec_max, current_value);
273 },
274 input);
275
276 float16x4_t carry_max = vpmax_f16(vget_high_f16(vec_max), vget_low_f16(vec_max));
277 carry_max = vpmax_f16(carry_max, carry_max);
278 carry_max = vpmax_f16(carry_max, carry_max);
279
280 *(reinterpret_cast<float16_t *>(output.ptr())) = vget_lane_f16(carry_max, 0);
281 }
282 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
283}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000284#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tellob49a7152017-07-11 16:31:35 +0100285
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100286void logits_1d_max_f32(const ITensor *in, ITensor *out, const Window &window)
287{
288 Window in_slice = window.first_slice_window_1D();
289
290 Window window_max(window);
291 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
292 Window max_slice = window_max.first_slice_window_1D();
293
294 do
295 {
296 Iterator input(in, in_slice);
297 Iterator output(out, max_slice);
298
299 float32x4_t vec_max = vdupq_n_f32(-FLT_MAX);
300
301 execute_window_loop(in_slice, [&](const Coordinates & id)
302 {
303 const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
304 const float32x4_t current_value = vld1q_f32(in_ptr);
305 vec_max = vmaxq_f32(vec_max, current_value);
306 },
307 input);
308
309 float32x2_t carry_max = vpmax_f32(vget_high_f32(vec_max), vget_low_f32(vec_max));
310 carry_max = vpmax_f32(carry_max, carry_max);
311
312 *(reinterpret_cast<float *>(output.ptr())) = vget_lane_f32(carry_max, 0);
313 }
314 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
315}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100316} // namespace
317
318NELogits1DMaxKernel::NELogits1DMaxKernel()
319 : _func(nullptr), _border_size()
320{
321}
322
323BorderSize NELogits1DMaxKernel::border_size() const
324{
325 return _border_size;
326}
327
328void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
329{
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000330 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100331
332 // Softmax across the x dimension
333 TensorShape output_shape{ input->info()->tensor_shape() };
334 output_shape.set(0, 1);
335
336 // Output auto initialization if not yet initialized
337 auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
338
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000339 // Perform validation step
340 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(input->info(), output->info()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100341
342 const int input_width = input->info()->valid_region().shape.x();
Georgios Pinitas9247c922017-06-28 18:29:47 +0100343 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100344
345 switch(input->info()->data_type())
346 {
347 case DataType::QS8:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100348 _func = &logits_1d_max_qs8;
349 break;
350 case DataType::QS16:
351 _func = &logits_1d_max_qs16;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100352 break;
353 case DataType::F32:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100354 _func = &logits_1d_max_f32;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100355 break;
Pablo Tellob49a7152017-07-11 16:31:35 +0100356 case DataType::F16:
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000357#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tellob49a7152017-07-11 16:31:35 +0100358 _func = &logits_1d_max_f16;
359 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000360#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100361 default:
362 ARM_COMPUTE_ERROR("Unsupported data type.");
363 }
364
365 _input = input;
366 _output = output;
Giorgio Arenaa2611812017-07-21 10:08:48 +0100367 _border_size = BorderSize(0, num_elems_processed_per_iteration - (input_width % num_elems_processed_per_iteration), 0, 0);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100368
369 // Configure kernel window
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000370 auto win_config = validate_and_configure_window_logits_1d_max(input->info(), output->info());
371 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
372 INEKernel::configure(win_config.second);
373}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100374
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000375Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
376{
377 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(input, output));
378 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_max(input->clone().get(), output->clone().get()).first);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100379
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000380 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100381}
382
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100383void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100384{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100385 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100386 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
387 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
388 ARM_COMPUTE_ERROR_ON(_func == nullptr);
389
390 (*_func)(_input, _output, window);
391}
392
393namespace
394{
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100395void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100396{
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100397 ARM_COMPUTE_UNUSED(beta);
398
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100399 Window window_max(window);
400 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
401
402 Window max_slice = window_max.first_slice_window_1D();
403 Window in_slice = window.first_slice_window_1D();
404
405 constexpr int step = 8;
406 const int long_steps = in->info()->valid_region().shape.x() / step;
407 const int small_steps = in->info()->valid_region().shape.x() % step;
408 const int fixed_point_position = in->info()->fixed_point_position();
409
410 do
411 {
412 Iterator input(in, in_slice);
413 Iterator exp(out, in_slice);
414 Iterator _max(max, max_slice);
415 Iterator _sum(sum, max_slice);
416
417 // Get pointers
418 auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
419 auto exp_ptr = reinterpret_cast<qint8_t *>(exp.ptr());
420
421 // Init sum to zero
422 qint16x8_t vec_sum_value = vdupq_n_qs16(0);
423
424 // Get max value
425 const auto max_ptr = reinterpret_cast<const qint8_t *>(_max.ptr());
426 const qint8x8_t vec_max = vdup_n_qs8(*max_ptr);
427
428 // Run neon loop
429 for(int i = 0; i < long_steps; ++i)
430 {
431 qint8x8_t vec_elements = vld1_qs8(in_ptr);
432 vec_elements = vqsub_qs8(vec_elements, vec_max);
433 vec_elements = vqexp_qs8(vec_elements, fixed_point_position);
434
435 vst1_qs8(exp_ptr, vec_elements);
436 vec_sum_value = vqaddq_qs16(vec_sum_value, vmovl_s8(vec_elements));
437
438 in_ptr += step;
439 exp_ptr += step;
440 }
441 // Reduce sum
442 const qint16x4_t sum_red = vqadd_qs16(vget_low_s16(vec_sum_value), vget_high_s16(vec_sum_value));
443 const qint16_t sum0 = sqadd_qs16(vget_lane_s16(sum_red, 0), vget_lane_s16(sum_red, 1));
444 const qint16_t sum1 = sqadd_qs16(vget_lane_s16(sum_red, 2), vget_lane_s16(sum_red, 3));
445 qint16_t sum = sqadd_qs16(sum0, sum1);
446
447 // Run remaining elements
448 for(int i = 0; i < small_steps; ++i)
449 {
450 qint8_t element = sqexp_qs8(sqsub_qs8(in_ptr[i], *max_ptr), fixed_point_position);
451 exp_ptr[i] = element;
452 sum = sqadd_qs16(sum, element);
453 }
454
455 *(reinterpret_cast<qint8_t *>(_sum.ptr())) = sqmovn_qs16(sum);
456 }
457 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
458}
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100459void logits_1d_shift_exp_sum_qs16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
Georgios Pinitas9247c922017-06-28 18:29:47 +0100460{
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100461 ARM_COMPUTE_UNUSED(beta);
462
Georgios Pinitas9247c922017-06-28 18:29:47 +0100463 Window window_max(window);
464 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
465
466 Window max_slice = window_max.first_slice_window_1D();
467 Window in_slice = window.first_slice_window_1D();
468
469 constexpr int step = 4;
470 const int long_steps = in->info()->valid_region().shape.x() / step;
471 const int small_steps = in->info()->valid_region().shape.x() % step;
472 const int fixed_point_position = in->info()->fixed_point_position();
473
474 do
475 {
476 Iterator input(in, in_slice);
477 Iterator exp(out, in_slice);
478 Iterator _max(max, max_slice);
479 Iterator _sum(sum, max_slice);
480
481 // Get pointers
482 auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
483 auto exp_ptr = reinterpret_cast<qint16_t *>(exp.ptr());
484
485 // Init sum to zero
486 qint32x4_t vec_sum_value = vdupq_n_qs32(0);
487
488 // Get max value
489 const auto max_ptr = reinterpret_cast<const qint16_t *>(_max.ptr());
490 const qint16x4_t vec_max = vdup_n_qs16(*max_ptr);
491
492 // Run neon loop
493 for(int i = 0; i < long_steps; ++i)
494 {
495 qint16x4_t vec_elements = vld1_qs16(in_ptr);
496 vec_elements = vqsub_qs16(vec_elements, vec_max);
497 vec_elements = vqexp_qs16(vec_elements, fixed_point_position);
498
499 vst1_qs16(exp_ptr, vec_elements);
500 vec_sum_value = vqaddq_qs32(vec_sum_value, vmovl_s16(vec_elements));
501
502 in_ptr += step;
503 exp_ptr += step;
504 }
505 // Reduce sum
506 qint32x2_t carry_addition = vqadd_qs32(vget_high_s32(vec_sum_value), vget_low_s32(vec_sum_value));
507 qint32_t sum = vget_lane_s32(carry_addition, 0) + vget_lane_s32(carry_addition, 1);
508
509 // Run remaining elements
510 for(int i = 0; i < small_steps; ++i)
511 {
512 qint16_t element = sqexp_qs16(sqsub_qs16(in_ptr[i], *max_ptr), fixed_point_position);
513 exp_ptr[i] = element;
514 sum = sqadd_qs32(sum, element);
515 }
516
517 *(reinterpret_cast<qint16_t *>(_sum.ptr())) = sqmovn_qs32(sum);
518 }
519 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
520}
Pablo Tellob49a7152017-07-11 16:31:35 +0100521
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000522#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100523void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
Pablo Tellob49a7152017-07-11 16:31:35 +0100524{
525 Window window_max(window);
526 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
527
528 Window max_slice = window_max.first_slice_window_1D();
529 Window in_slice = window.first_slice_window_1D();
530
531 constexpr int step = 8;
532 const int long_steps = in->info()->valid_region().shape.x() / step;
533 const int small_steps = in->info()->valid_region().shape.x() % step;
534
535 do
536 {
537 Iterator input(in, in_slice);
538 Iterator exp(out, in_slice);
539 Iterator _max(max, max_slice);
540 Iterator _sum(sum, max_slice);
541
542 // Get pointers
543 auto in_ptr = reinterpret_cast<const float16_t *>(input.ptr());
544 auto exp_ptr = reinterpret_cast<float16_t *>(exp.ptr());
545
546 // Init sum to zero
547 float16x8_t vec_sum_value = vdupq_n_f16(0);
548
549 // Get max value
550 const auto max_ptr = reinterpret_cast<const float16_t *>(_max.ptr());
551 const float16x8_t vec_max = vdupq_n_f16(*max_ptr);
552
553 // Run neon loop
554 for(int i = 0; i < long_steps; ++i)
555 {
556 float16x8_t vec_elements = vld1q_f16(in_ptr);
557 vec_elements = vsubq_f16(vec_elements, vec_max);
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100558 vec_elements = vmulq_n_f16(vec_elements, beta);
Pablo Tellob49a7152017-07-11 16:31:35 +0100559 vec_elements = vexpq_f16(vec_elements);
560
561 vst1q_f16(exp_ptr, vec_elements);
562 vec_sum_value = vaddq_f16(vec_sum_value, vec_elements);
563
564 in_ptr += step;
565 exp_ptr += step;
566 }
567 // Reduce sum
568 const float16x4_t sum_red = vadd_f16(vget_low_f16(vec_sum_value), vget_high_f16(vec_sum_value));
569 const float16x4_t carry_addition = vpadd_f16(sum_red, sum_red);
570 float16_t sum = vget_lane_f16(carry_addition, 0) + vget_lane_f16(carry_addition, 1);
571
572 // Run remaining elements
573 for(int i = 0; i < small_steps; ++i)
574 {
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100575 const float16_t element = std::exp(static_cast<float>(in_ptr[i] - *max_ptr) * beta);
Pablo Tellob49a7152017-07-11 16:31:35 +0100576 exp_ptr[i] = element;
577 sum += element;
578 }
579 *(reinterpret_cast<float16_t *>(_sum.ptr())) = sum;
580 }
581 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
582}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000583#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tellob49a7152017-07-11 16:31:35 +0100584
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100585void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
Georgios Pinitas9247c922017-06-28 18:29:47 +0100586{
587 Window window_max(window);
588 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
589
590 Window max_slice = window_max.first_slice_window_1D();
591 Window in_slice = window.first_slice_window_1D();
592
593 constexpr int step = 4;
594 const int long_steps = in->info()->valid_region().shape.x() / step;
595 const int small_steps = in->info()->valid_region().shape.x() % step;
596
597 do
598 {
599 Iterator input(in, in_slice);
600 Iterator exp(out, in_slice);
601 Iterator _max(max, max_slice);
602 Iterator _sum(sum, max_slice);
603
604 // Get pointers
605 auto in_ptr = reinterpret_cast<const float *>(input.ptr());
606 auto exp_ptr = reinterpret_cast<float *>(exp.ptr());
607
608 // Init sum to zero
609 float32x4_t vec_sum_value = vdupq_n_f32(0.0f);
610
611 // Get max value
612 const auto max_ptr = reinterpret_cast<const float *>(_max.ptr());
613 const float32x4_t vec_max = vdupq_n_f32(*max_ptr);
614
615 // Run neon loop
616 for(int i = 0; i < long_steps; ++i)
617 {
618 float32x4_t vec_elements = vld1q_f32(in_ptr);
619 vec_elements = vsubq_f32(vec_elements, vec_max);
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100620 vec_elements = vmulq_n_f32(vec_elements, beta);
Georgios Pinitas9247c922017-06-28 18:29:47 +0100621 vec_elements = vexpq_f32(vec_elements);
622
623 vst1q_f32(exp_ptr, vec_elements);
624 vec_sum_value = vaddq_f32(vec_elements, vec_sum_value);
625
626 in_ptr += step;
627 exp_ptr += step;
628 }
629
630 // Reduce sum
631 float32x2_t carry_addition = vpadd_f32(vget_high_f32(vec_sum_value), vget_low_f32(vec_sum_value));
632 carry_addition = vpadd_f32(carry_addition, carry_addition);
633 float sum = vget_lane_f32(carry_addition, 0);
634
635 // Run remaining elements
636 for(int i = 0; i < small_steps; ++i)
637 {
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100638 float element = std::exp((in_ptr[i] - *max_ptr) * beta);
Georgios Pinitas9247c922017-06-28 18:29:47 +0100639 exp_ptr[i] = element;
640 sum += element;
641 }
642
643 *(reinterpret_cast<float *>(_sum.ptr())) = sum;
644 }
645 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
646}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100647} //namespace
648
649NELogits1DShiftExpSumKernel::NELogits1DShiftExpSumKernel()
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100650 : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr), _beta(1.0f)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100651{
652}
653
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100654void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum, float beta)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100655{
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000656 ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, sum, output);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100657
658 // Output auto initialization if not yet initialized
659 auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
660 auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
661
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000662 // Perform validation step
663 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_shift_exp_sum(input->info(), max->info(), output->info(), sum->info(), beta));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100664
665 switch(input->info()->data_type())
666 {
667 case DataType::QS8:
668 _func = &logits_1d_shift_exp_sum_qs8;
669 break;
Georgios Pinitas9247c922017-06-28 18:29:47 +0100670 case DataType::QS16:
671 _func = &logits_1d_shift_exp_sum_qs16;
672 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100673 case DataType::F32:
674 _func = &logits_1d_shift_exp_sum_f32;
675 break;
Pablo Tellob49a7152017-07-11 16:31:35 +0100676 case DataType::F16:
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000677#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tellob49a7152017-07-11 16:31:35 +0100678 _func = &logits_1d_shift_exp_sum_f16;
679 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000680#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100681 default:
682 ARM_COMPUTE_ERROR("Unsupported data type.");
Pablo Tellob49a7152017-07-11 16:31:35 +0100683 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100684 }
685
686 _input = input;
687 _max = max;
688 _output = output;
689 _sum = sum;
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100690 _beta = beta;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100691
692 // Configure kernel window
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000693 auto win_config = validate_and_configure_window_logits_1d_shift_exp_sum(input->info(), max->info(), output->info(), sum->info());
694 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
695 INEKernel::configure(win_config.second);
696}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100697
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000698Status NELogits1DShiftExpSumKernel::validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum, float beta)
699{
700 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_shift_exp_sum(input, max, output, sum, beta));
701 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_shift_exp_sum(input->clone().get(), max->clone().get(), output->clone().get(), sum->clone().get()).first);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100702
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000703 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100704}
705
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100706void NELogits1DShiftExpSumKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100707{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100708 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100709 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
710 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
711 ARM_COMPUTE_ERROR_ON(_func == nullptr);
712
Pablo Palmiera2b89ca2017-10-05 15:01:34 +0100713 (*_func)(_input, _max, _output, _sum, window, _beta);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100714}
715
716namespace
717{
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100718void logits_1d_norm_qs8(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
719{
720 Window window_sum(window);
721 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
722 Window sum_slice = window_sum.first_slice_window_1D();
723 Window in_slice = window.first_slice_window_1D();
724
725 const int fixed_point_position = in->info()->fixed_point_position();
726
727 do
728 {
729 Iterator input(in, in_slice);
730 Iterator _sum(sum, sum_slice);
731 Iterator output(out, in_slice);
732
733 const int8_t sum_value = *reinterpret_cast<const qint8_t *>(_sum.ptr());
734 const qint8x16_t vec_sum_inversed = vqrecipq_qs8(vdupq_n_qs8(sum_value), fixed_point_position);
735
736 execute_window_loop(in_slice, [&](const Coordinates & id)
737 {
738 const auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
739 const auto out_ptr = reinterpret_cast<qint8_t *>(output.ptr());
740
741 const qint8x16_t vec_in = vld1q_qs8(in_ptr);
742 const qint8x16_t normalized_value = vqmulq_qs8(vec_in, vec_sum_inversed, fixed_point_position);
743
744 vst1q_qs8(out_ptr, normalized_value);
745 },
746 input, output);
747 }
748 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
749}
Georgios Pinitas9247c922017-06-28 18:29:47 +0100750void logits_1d_norm_qs16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
751{
752 Window window_sum(window);
753 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
754 Window sum_slice = window_sum.first_slice_window_1D();
755 Window in_slice = window.first_slice_window_1D();
756
757 const int fixed_point_position = in->info()->fixed_point_position();
758
759 do
760 {
761 Iterator input(in, in_slice);
762 Iterator _sum(sum, sum_slice);
763 Iterator output(out, in_slice);
764
765 const int16_t sum_value = *reinterpret_cast<const qint16_t *>(_sum.ptr());
766 const qint16x8_t vec_sum_inversed = vqrecipq_qs16(vdupq_n_qs16(sum_value), fixed_point_position);
767
768 execute_window_loop(in_slice, [&](const Coordinates & id)
769 {
770 const auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
771 const auto out_ptr = reinterpret_cast<qint16_t *>(output.ptr());
772
773 const qint16x8_t vec_in = vld1q_qs16(in_ptr);
774 const qint16x8_t normalized_value = vqmulq_qs16(vec_in, vec_sum_inversed, fixed_point_position);
775
776 vst1q_qs16(out_ptr, normalized_value);
777 },
778 input, output);
779 }
780 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
781}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000782#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tellob49a7152017-07-11 16:31:35 +0100783void logits_1d_norm_f16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
784{
785 Window window_sum(window);
786 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
787 Window sum_slice = window_sum.first_slice_window_1D();
788 Window in_slice = window.first_slice_window_1D();
789
790 do
791 {
792 Iterator input(in, in_slice);
793 Iterator _sum(sum, sum_slice);
794 Iterator output(out, in_slice);
795
796 const float16_t sum_value = *reinterpret_cast<const qint16_t *>(_sum.ptr());
797 const float16x8_t vec_sum_inversed = vdupq_n_f16(1.0f / sum_value);
798
799 execute_window_loop(in_slice, [&](const Coordinates & id)
800 {
801 const auto in_ptr = reinterpret_cast<const float16_t *>(input.ptr());
802 const auto out_ptr = reinterpret_cast<float16_t *>(output.ptr());
803
804 const float16x8_t vec_in = vld1q_f16(in_ptr);
805 const float16x8_t normalized_value = vmulq_f16(vec_in, vec_sum_inversed);
806
807 vst1q_f16(out_ptr, normalized_value);
808 },
809 input, output);
810 }
811 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
812}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000813#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tellob49a7152017-07-11 16:31:35 +0100814
Georgios Pinitas9247c922017-06-28 18:29:47 +0100815void logits_1d_norm_f32(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
816{
817 Window window_sum(window);
818 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
819 Window sum_slice = window_sum.first_slice_window_1D();
820 Window in_slice = window.first_slice_window_1D();
821
822 do
823 {
824 Iterator input(in, in_slice);
825 Iterator _sum(sum, sum_slice);
826 Iterator output(out, in_slice);
827
828 const float sum_value = *reinterpret_cast<const float *>(_sum.ptr());
829 const float32x4_t vec_sum_inversed = vdupq_n_f32(1.0f / sum_value);
830
831 execute_window_loop(in_slice, [&](const Coordinates & id)
832 {
833 const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
834 const auto out_ptr = reinterpret_cast<float *>(output.ptr());
835
836 const float32x4_t vec_in = vld1q_f32(in_ptr);
837 const float32x4_t normalized_value = vmulq_f32(vec_in, vec_sum_inversed);
838
839 vst1q_f32(out_ptr, normalized_value);
840 },
841 input, output);
842 }
843 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
844}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100845} // namespace
846
847NELogits1DNormKernel::NELogits1DNormKernel()
848 : _func(nullptr), _input(nullptr), _sum(nullptr), _output(nullptr)
849{
850}
851
852void NELogits1DNormKernel::configure(const ITensor *input, const ITensor *sum, ITensor *output)
853{
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000854 ARM_COMPUTE_ERROR_ON_NULLPTR(input, sum, output);
Georgios Pinitasd368df32017-07-04 11:06:15 +0100855
856 // Output auto initialization if not yet initialized
857 auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
858
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000859 // Perform validation step
860 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_norm(input->info(), sum->info(), output->info()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100861
862 _input = input;
863 _sum = sum;
864 _output = output;
865
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100866 switch(input->info()->data_type())
867 {
868 case DataType::QS8:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100869 _func = &logits_1d_norm_qs8;
870 break;
871 case DataType::QS16:
872 _func = &logits_1d_norm_qs16;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100873 break;
874 case DataType::F32:
Georgios Pinitas9247c922017-06-28 18:29:47 +0100875 _func = &logits_1d_norm_f32;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100876 break;
Pablo Tellob49a7152017-07-11 16:31:35 +0100877 case DataType::F16:
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000878#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tellob49a7152017-07-11 16:31:35 +0100879 _func = &logits_1d_norm_f16;
880 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000881#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100882 default:
883 ARM_COMPUTE_ERROR("Unsupported data type.");
Pablo Tellob49a7152017-07-11 16:31:35 +0100884 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100885 }
886
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000887 // Configure kernel window
888 auto win_config = validate_and_configure_window_logits_1d_norm(input->info(), sum->info(), output->info());
889 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
890 INEKernel::configure(win_config.second);
891}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100892
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000893Status NELogits1DNormKernel::validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
894{
895 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_norm(input, sum, output));
896 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_norm(input->clone().get(), sum->clone().get(), output->clone().get()).first);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100897
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000898 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100899}
900
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100901void NELogits1DNormKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100902{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100903 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100904 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
905 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
906 ARM_COMPUTE_ERROR_ON(_func == nullptr);
907
908 (*_func)(_input, _sum, _output, window);
909}