blob: 856e3acb354019756d5cac3bbd958bdd752860f3 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2016, 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEAccumulateKernel.h"
25
26#include "arm_compute/core/Error.h"
27#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/IAccessWindow.h"
29#include "arm_compute/core/Types.h"
30#include "arm_compute/core/Validate.h"
31
32#include <arm_neon.h>
33
34using namespace arm_compute;
35
// Forward declaration: the execute_window_loop lambdas in this file take `const Coordinates &`.
namespace arm_compute { class Coordinates; } // namespace arm_compute
40
41/* Max S16 value used for saturation purposes. */
42const static uint16x8_t max_int_u16 = vdupq_n_u16(static_cast<uint16_t>(INT16_MAX));
43
Ioan-Cristian Szabo33fd07b2017-10-26 15:42:24 +010044#ifdef ARM_COMPUTE_AARCH64_V8_2
Anthony Barbier6ff3b192017-09-04 18:44:23 +010045namespace fp16
46{
47inline float16x8x2_t convert_u8x16_to_f16x8x2(uint8x16_t input)
48{
49 const float16x8x2_t out =
50 {
51 {
52 vcvtq_f16_u16(vmovl_u8(vget_low_u8(input))),
53 vcvtq_f16_u16(vmovl_u8(vget_high_u8(input)))
54 }
55 };
56
57 return out;
58}
59
60inline uint8x16_t convert_f16x8x2_to_u8x16(const float16x8x2_t &input)
61{
62 return vcombine_u8(vmovn_u16(vcvtq_u16_f16(input.val[0])),
63 vmovn_u16(vcvtq_u16_f16(input.val[1])));
64}
65
66inline float16x8x2_t vector_accumulate_weighted(const float16x8x2_t &vec0, const float16x8x2_t &vec1, float16x8_t scale_val, float16x8_t scale_val2)
67{
68 const float16x8x2_t res =
69 {
70 {
71 vfmaq_f16(vmulq_f16(vec1.val[0], scale_val), vec0.val[0], scale_val2),
72 vfmaq_f16(vmulq_f16(vec1.val[1], scale_val), vec0.val[1], scale_val2)
73 }
74 };
75
76 return res;
77}
78
79void acc_we_v16_u8(const void *__restrict input, void *__restrict accum, float16x8_t scale_val, float16x8_t scale_val2)
80{
81 ARM_COMPUTE_ERROR_ON(nullptr == input);
82 ARM_COMPUTE_ERROR_ON(nullptr == accum);
83
84 const auto input_ptr = static_cast<const uint8_t *__restrict>(input);
85 const auto accum_ptr = static_cast<uint8_t *__restrict>(accum);
86
87 const uint8x16x4_t input_buffer = vld4q_u8(input_ptr);
88 uint8x16x4_t accum_buffer = vld4q_u8(accum_ptr);
89
90 const float16x8x2_t f16_input_0 = convert_u8x16_to_f16x8x2(input_buffer.val[0]);
91 const float16x8x2_t f16_input_1 = convert_u8x16_to_f16x8x2(input_buffer.val[1]);
92 const float16x8x2_t f16_input_2 = convert_u8x16_to_f16x8x2(input_buffer.val[2]);
93 const float16x8x2_t f16_input_3 = convert_u8x16_to_f16x8x2(input_buffer.val[3]);
94
95 float16x8x2_t f16_accum_0 = convert_u8x16_to_f16x8x2(accum_buffer.val[0]);
96 float16x8x2_t f16_accum_1 = convert_u8x16_to_f16x8x2(accum_buffer.val[1]);
97 float16x8x2_t f16_accum_2 = convert_u8x16_to_f16x8x2(accum_buffer.val[2]);
98 float16x8x2_t f16_accum_3 = convert_u8x16_to_f16x8x2(accum_buffer.val[3]);
99
100 f16_accum_0 = vector_accumulate_weighted(f16_input_0, f16_accum_0, scale_val, scale_val2);
101 f16_accum_1 = vector_accumulate_weighted(f16_input_1, f16_accum_1, scale_val, scale_val2);
102 f16_accum_2 = vector_accumulate_weighted(f16_input_2, f16_accum_2, scale_val, scale_val2);
103 f16_accum_3 = vector_accumulate_weighted(f16_input_3, f16_accum_3, scale_val, scale_val2);
104
105 accum_buffer = { {
106 convert_f16x8x2_to_u8x16(f16_accum_0),
107 convert_f16x8x2_to_u8x16(f16_accum_1),
108 convert_f16x8x2_to_u8x16(f16_accum_2),
109 convert_f16x8x2_to_u8x16(f16_accum_3)
110 }
111 };
112
113 vst4q_u8(accum_ptr, accum_buffer);
114}
115} // namespace fp16
116
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100117void NEAccumulateWeightedFP16Kernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100118{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100119 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100120 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
121 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INESimpleKernel::window(), window);
122
123 Iterator input(_input, window);
124 Iterator accum(_output, window);
125
126 const float16x8_t scale_val = vdupq_n_f16(1.f - _alpha);
127 const float16x8_t scale_val2 = vdupq_n_f16(_alpha);
128
129 execute_window_loop(window, [&](const Coordinates & id)
130 {
131 fp16::acc_we_v16_u8(input.ptr(), accum.ptr(), scale_val, scale_val2);
132 },
133 input, accum);
134}
Ioan-Cristian Szabo33fd07b2017-10-26 15:42:24 +0100135#endif /* ARM_COMPUTE_AARCH64_V8_2 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100136
137namespace
138{
139inline void acc_v16_u8(const void *__restrict input, void *__restrict accum)
140{
141 ARM_COMPUTE_ERROR_ON(nullptr == input);
142 ARM_COMPUTE_ERROR_ON(nullptr == accum);
143
144 const auto in = static_cast<const uint8_t *__restrict>(input);
145 const auto out = static_cast<int16_t *__restrict>(accum);
146
147 uint8x16_t ta1 = vld1q_u8(in);
148 int16x8_t ta2 = vld1q_s16(out);
149 int16x8_t ta3 = vld1q_s16(out + 8);
150
151 ta2 = vqaddq_s16(ta2, vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(ta1))));
152 ta3 = vqaddq_s16(ta3, vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(ta1))));
153
154 vst1q_s16(out, ta2);
155 vst1q_s16(out + 8, ta3);
156}
157
158inline float32x4x4_t convert_u8x16_to_f32x4x4(uint8x16_t input)
159{
160 const uint16x8_t u16_output_low = vmovl_u8(vget_low_u8(input));
161 const uint16x8_t u16_output_hi = vmovl_u8(vget_high_u8(input));
162
163 const float32x4x4_t res =
164 {
165 {
166 vcvtq_f32_u32(vmovl_u16(vget_low_u16(u16_output_low))),
167 vcvtq_f32_u32(vmovl_u16(vget_high_u16(u16_output_low))),
168 vcvtq_f32_u32(vmovl_u16(vget_low_u16(u16_output_hi))),
169 vcvtq_f32_u32(vmovl_u16(vget_high_u16(u16_output_hi)))
170 }
171 };
172
173 return res;
174}
175
176inline uint8x16_t convert_f32x4x4_to_u8x16(const float32x4x4_t &input)
177{
178 return vcombine_u8(vmovn_u16(vcombine_u16(vmovn_u32(vcvtq_u32_f32(input.val[0])),
179 vmovn_u32(vcvtq_u32_f32(input.val[1])))),
180 vmovn_u16(vcombine_u16(vmovn_u32(vcvtq_u32_f32(input.val[2])),
181 vmovn_u32(vcvtq_u32_f32(input.val[3])))));
182}
183
184inline float32x4x4_t vector_accumulate_weighted(const float32x4x4_t &vector_input, float32x4x4_t vector_output, float32x4_t scale_val, float32x4_t scale_val2)
185{
186 vector_output.val[0] = vmulq_f32(vector_output.val[0], scale_val);
187 vector_output.val[1] = vmulq_f32(vector_output.val[1], scale_val);
188 vector_output.val[2] = vmulq_f32(vector_output.val[2], scale_val);
189 vector_output.val[3] = vmulq_f32(vector_output.val[3], scale_val);
190
191 vector_output.val[0] = vmlaq_f32(vector_output.val[0], vector_input.val[0], scale_val2);
192 vector_output.val[1] = vmlaq_f32(vector_output.val[1], vector_input.val[1], scale_val2);
193 vector_output.val[2] = vmlaq_f32(vector_output.val[2], vector_input.val[2], scale_val2);
194 vector_output.val[3] = vmlaq_f32(vector_output.val[3], vector_input.val[3], scale_val2);
195
196 return vector_output;
197}
198
199inline void acc_we_v16_u8(const void *__restrict input, void *__restrict accum, const float32x4_t scale_val, const float32x4_t scale_val2)
200{
201 ARM_COMPUTE_ERROR_ON(nullptr == input);
202 ARM_COMPUTE_ERROR_ON(nullptr == accum);
203
204 const auto input_ptr = static_cast<const uint8_t *__restrict>(input);
205 const auto accum_ptr = static_cast<uint8_t *__restrict>(accum);
206
207 const uint8x16_t input_buffer = vld1q_u8(input_ptr);
208 const uint8x16_t accum_buffer = vld1q_u8(accum_ptr);
209
210 const float32x4x4_t f32_input_0 = convert_u8x16_to_f32x4x4(input_buffer);
211 const float32x4x4_t f32_output_0 = convert_u8x16_to_f32x4x4(accum_buffer);
212
213 const float32x4x4_t f32_res_0 = vector_accumulate_weighted(f32_input_0, f32_output_0, scale_val, scale_val2);
214
215 vst1q_u8(accum_ptr, convert_f32x4x4_to_u8x16(f32_res_0));
216}
217
218void acc_sq_v16_u8(const void *__restrict input, uint32_t shift, void *__restrict accum)
219{
220 ARM_COMPUTE_ERROR_ON(nullptr == input);
221 ARM_COMPUTE_ERROR_ON(nullptr == accum);
222 ARM_COMPUTE_ERROR_ON(shift > 15);
223
224 const auto input_buffer = static_cast<const uint8_t *__restrict>(input);
225 const auto accum_buffer = static_cast<int16_t *__restrict>(accum);
226
227 const uint8x16_t ta1 = vld1q_u8(input_buffer);
228 uint16x8_t ta2 = vreinterpretq_u16_s16(vld1q_s16(accum_buffer));
229 uint16x8_t ta3 = vreinterpretq_u16_s16(vld1q_s16(accum_buffer + 8));
230
231 const int16x8_t vector_shift = vdupq_n_s16(-static_cast<int16_t>(shift));
232
233 uint16x8_t linput = vmovl_u8(vget_low_u8(ta1));
234 uint16x8_t hinput = vmovl_u8(vget_high_u8(ta1));
235
236 linput = vmulq_u16(linput, linput);
237 hinput = vmulq_u16(hinput, hinput);
238
239 linput = vqshlq_u16(linput, vector_shift);
240 hinput = vqshlq_u16(hinput, vector_shift);
241
242 ta2 = vqaddq_u16(ta2, linput);
243 ta3 = vqaddq_u16(ta3, hinput);
244
245 vst1q_s16(accum_buffer, vreinterpretq_s16_u16(vminq_u16(max_int_u16, ta2)));
246 vst1q_s16(accum_buffer + 8, vreinterpretq_s16_u16(vminq_u16(max_int_u16, ta3)));
247}
248} // namespace
249
250void NEAccumulateKernel::configure(const ITensor *input, ITensor *accum)
251{
252 ARM_COMPUTE_ERROR_ON_NULLPTR(input, accum);
253
254 set_shape_if_empty(*accum->info(), input->info()->tensor_shape());
255
256 set_format_if_unknown(*accum->info(), Format::S16);
257
258 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
259 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(accum, 1, DataType::S16);
260 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, accum);
261
262 constexpr unsigned int num_elems_processed_per_iteration = 16;
263 INESimpleKernel::configure(input, accum, num_elems_processed_per_iteration);
264}
265
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100266void NEAccumulateKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100267{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100268 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100269 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
270 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INESimpleKernel::window(), window);
271 Iterator input(_input, window);
272 Iterator accum(_output, window);
273
274 execute_window_loop(window, [&](const Coordinates & id)
275 {
276 acc_v16_u8(input.ptr(), accum.ptr());
277 },
278 input, accum);
279}
280
281NEAccumulateWeightedKernel::NEAccumulateWeightedKernel()
282 : _alpha(0.0f)
283{
284}
285
286void NEAccumulateWeightedKernel::configure(const ITensor *input, float alpha, ITensor *accum)
287{
288 ARM_COMPUTE_ERROR_ON_NULLPTR(input, accum);
289
290 set_shape_if_empty(*accum->info(), input->info()->tensor_shape());
291
292 set_format_if_unknown(*accum->info(), Format::U8);
293
294 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, accum);
295 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
296 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(accum, 1, DataType::U8);
297 ARM_COMPUTE_ERROR_ON(alpha < 0.0 || alpha > 1.0);
298
299 _alpha = alpha;
300
301 constexpr unsigned int num_elems_processed_per_iteration = 16;
302 INESimpleKernel::configure(input, accum, num_elems_processed_per_iteration);
303}
304
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100305void NEAccumulateWeightedKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100306{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100307 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100308 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
309 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INESimpleKernel::window(), window);
310
311 Iterator input(_input, window);
312 Iterator accum(_output, window);
313
314 const float32x4_t scale_val = vdupq_n_f32(1.f - _alpha);
315 const float32x4_t scale_val2 = vdupq_n_f32(_alpha);
316
317 execute_window_loop(window, [&](const Coordinates & id)
318 {
319 acc_we_v16_u8(input.ptr(), accum.ptr(), scale_val, scale_val2);
320 },
321 input, accum);
322}
323
324NEAccumulateSquaredKernel::NEAccumulateSquaredKernel()
325 : _shift(0)
326{
327}
328
329void NEAccumulateSquaredKernel::configure(const ITensor *input, uint32_t shift, ITensor *accum)
330{
331 ARM_COMPUTE_ERROR_ON_NULLPTR(input, accum);
332
333 set_shape_if_empty(*accum->info(), input->info()->tensor_shape());
334
335 set_format_if_unknown(*accum->info(), Format::S16);
336
337 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, accum);
338 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
339 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(accum, 1, DataType::S16);
340 ARM_COMPUTE_ERROR_ON(shift > 15);
341
342 _shift = shift;
343
344 constexpr unsigned int num_elems_processed_per_iteration = 16;
345 INESimpleKernel::configure(input, accum, num_elems_processed_per_iteration);
346}
347
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100348void NEAccumulateSquaredKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100349{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100350 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100351 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
352 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INESimpleKernel::window(), window);
353 Iterator input(_input, window);
354 Iterator accum(_output, window);
355
356 execute_window_loop(window, [&](const Coordinates & id)
357 {
358 acc_sq_v16_u8(input.ptr(), _shift, accum.ptr());
359 },
360 input, accum);
361}