/*
 * Copyright (c) 2016, 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/core/NEON/kernels/NEArithmeticAdditionKernel.h"
25
26#include "arm_compute/core/Error.h"
27#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/IAccessWindow.h"
29#include "arm_compute/core/ITensor.h"
30#include "arm_compute/core/TensorInfo.h"
31#include "arm_compute/core/Validate.h"
32
33#include <algorithm>
34#include <arm_neon.h>
35#include <cstdint>
36#include <map>
37#include <string>
38
39using namespace arm_compute;
40
// Forward declaration of Coordinates: in this file it is only named as the
// (const reference) parameter type of the lambdas passed to execute_window_loop.
41namespace arm_compute
42{
43class Coordinates;
44} // namespace arm_compute
45
46namespace
47{
48void add_wrap_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
49{
50 Iterator input1(in1, window);
51 Iterator input2(in2, window);
52 Iterator output(out, window);
53
54 execute_window_loop(window, [&](const Coordinates & id)
55 {
56 vst1q_u8(output.ptr(), vaddq_u8(vld1q_u8(input1.ptr()), vld1q_u8(input2.ptr())));
57 },
58 input1, input2, output);
59}
60
61void add_saturate_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
62{
63 Iterator input1(in1, window);
64 Iterator input2(in2, window);
65 Iterator output(out, window);
66
67 execute_window_loop(window, [&](const Coordinates & id)
68 {
69 vst1q_u8(output.ptr(), vqaddq_u8(vld1q_u8(input1.ptr()), vld1q_u8(input2.ptr())));
70 },
71 input1, input2, output);
72}
73
74inline int16x8x2_t vadd2q_s16(const int16x8x2_t &a, const int16x8x2_t &b)
75{
76 const int16x8x2_t res =
77 {
78 {
79 vaddq_s16(a.val[0], b.val[0]),
80 vaddq_s16(a.val[1], b.val[1])
81 }
82 };
83
84 return res;
85}
86
87inline float32x4x4_t vadd4q_f32(const float32x4x4_t &a, const float32x4x4_t &b)
88{
89 const float32x4x4_t res =
90 {
91 {
92 vaddq_f32(a.val[0], b.val[0]),
93 vaddq_f32(a.val[1], b.val[1]),
94 vaddq_f32(a.val[2], b.val[2]),
95 vaddq_f32(a.val[3], b.val[3])
96 }
97 };
98
99 return res;
100}
101
102inline int16x8x2_t vqadd2q_s16(const int16x8x2_t &a, const int16x8x2_t &b)
103{
104 const int16x8x2_t res =
105 {
106 {
107 vqaddq_s16(a.val[0], b.val[0]),
108 vqaddq_s16(a.val[1], b.val[1])
109 }
110 };
111
112 return res;
113}
114
115void add_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
116{
117 Iterator input1(in1, window);
118 Iterator input2(in2, window);
119 Iterator output(out, window);
120
121 execute_window_loop(window, [&](const Coordinates & id)
122 {
123 const float32x4x4_t a = vld4q_f32(reinterpret_cast<const float *>(input1.ptr()));
124 const float32x4x4_t b = vld4q_f32(reinterpret_cast<const float *>(input2.ptr()));
125
126 vst4q_f32(reinterpret_cast<float *>(output.ptr()), vadd4q_f32(a, b));
127 },
128 input1, input2, output);
129}
130
131void add_wrap_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
132{
133 Iterator input1(in1, window);
134 Iterator input2(in2, window);
135 Iterator output(out, window);
136
137 execute_window_loop(window, [&](const Coordinates & id)
138 {
139 const int16x8x2_t a = vld2q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
140 const int16x8x2_t b = vld2q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
141
142 vst2q_s16(reinterpret_cast<int16_t *>(output.ptr()), vadd2q_s16(a, b));
143 },
144 input1, input2, output);
145}
146
147void add_saturate_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
148{
149 Iterator input1(in1, window);
150 Iterator input2(in2, window);
151 Iterator output(out, window);
152
153 execute_window_loop(window, [&](const Coordinates & id)
154 {
155 const int16x8x2_t a = vld2q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
156 const int16x8x2_t b = vld2q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
157
158 vst2q_s16(reinterpret_cast<int16_t *>(output.ptr()), vqadd2q_s16(a, b));
159 },
160 input1, input2, output);
161}
162
163void add_wrap_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
164{
165 Iterator input1(in1, window);
166 Iterator input2(in2, window);
167 Iterator output(out, window);
168
169 execute_window_loop(window, [&](const Coordinates & id)
170 {
171 const int16x8x2_t a =
172 {
173 {
174 vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr())),
175 vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()) + 8)
176 }
177 };
178 const uint8x16_t b = vld1q_u8(input2.ptr());
179
180 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), vaddq_s16(a.val[0], vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(b)))));
181 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, vaddq_s16(a.val[1], vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(b)))));
182 },
183 input1, input2, output);
184}
185
186void add_saturate_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
187{
188 Iterator input1(in1, window);
189 Iterator input2(in2, window);
190 Iterator output(out, window);
191
192 execute_window_loop(window, [&](const Coordinates & id)
193 {
194 const int16x8x2_t a =
195 {
196 {
197 vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr())),
198 vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()) + 8)
199 }
200 };
201 const uint8x16_t b = vld1q_u8(input2.ptr());
202
203 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), vqaddq_s16(a.val[0], vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(b)))));
204 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, vqaddq_s16(a.val[1], vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(b)))));
205 },
206 input1, input2, output);
207}
208
209inline void add_wrap_U8_S16_S16(const ITensor *input1, const ITensor *input2, ITensor *output, const Window &window)
210{
211 //Simply swap the two input buffers:
212 add_wrap_S16_U8_S16(input2, input1, output, window);
213}
214
215inline void add_saturate_U8_S16_S16(const ITensor *input1, const ITensor *input2, ITensor *output, const Window &window)
216{
217 //Simply swap the two input buffers:
218 add_saturate_S16_U8_S16(input2, input1, output, window);
219}
220
221void add_wrap_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
222{
223 Iterator input1(in1, window);
224 Iterator input2(in2, window);
225 Iterator output(out, window);
226
227 execute_window_loop(window, [&](const Coordinates & id)
228 {
229 const uint8x16_t a = vld1q_u8(input1.ptr());
230 const uint8x16_t b = vld1q_u8(input2.ptr());
231
232 const int16x8x2_t a_s16 =
233 {
234 {
235 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(a))),
236 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(a)))
237 }
238 };
239
240 const int16x8x2_t b_s16 =
241 {
242 {
243 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(b))),
244 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(b)))
245 }
246 };
247
248 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), vaddq_s16(a_s16.val[0], b_s16.val[0]));
249 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, vaddq_s16(a_s16.val[1], b_s16.val[1]));
250 },
251 input1, input2, output);
252}
253
254void add_saturate_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
255{
256 Iterator input1(in1, window);
257 Iterator input2(in2, window);
258 Iterator output(out, window);
259
260 execute_window_loop(window, [&](const Coordinates & id)
261 {
262 const uint8x16_t a = vld1q_u8(input1.ptr());
263 const uint8x16_t b = vld1q_u8(input2.ptr());
264
265 const int16x8x2_t a_s16 =
266 {
267 {
268 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(a))),
269 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(a)))
270 }
271 };
272
273 const int16x8x2_t b_s16 =
274 {
275 {
276 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(b))),
277 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(b)))
278 }
279 };
280
281 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), vqaddq_s16(a_s16.val[0], b_s16.val[0]));
282 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, vqaddq_s16(a_s16.val[1], b_s16.val[1]));
283 },
284 input1, input2, output);
285}
286} // namespace
287
288NEArithmeticAdditionKernel::NEArithmeticAdditionKernel()
289 : _func(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr)
290{
291}
292
293void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy)
294{
295 ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
296
Georgios Pinitasf0dea702017-07-03 18:17:28 +0100297 // Auto initialize output if not initialized
298 {
299 set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100300
Georgios Pinitasf0dea702017-07-03 18:17:28 +0100301 if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
302 {
303 set_format_if_unknown(*output->info(), Format::S16);
304 }
305 else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
306 {
307 set_format_if_unknown(*output->info(), Format::F32);
308 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100309 }
310
311 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
312 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F32);
313 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F32);
314 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F32);
315 ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
316 "Output can only be U8 if both inputs are U8");
317
318 static std::map<std::string, AddFunction *> map_function =
319 {
320 { "add_wrap_U8_U8_U8", &add_wrap_U8_U8_U8 },
321 { "add_saturate_U8_U8_U8", &add_saturate_U8_U8_U8 },
322 { "add_wrap_S16_U8_S16", &add_wrap_S16_U8_S16 },
323 { "add_saturate_S16_U8_S16", &add_saturate_S16_U8_S16 },
324 { "add_wrap_U8_S16_S16", &add_wrap_U8_S16_S16 },
325 { "add_saturate_U8_S16_S16", &add_saturate_U8_S16_S16 },
326 { "add_wrap_U8_U8_S16", &add_wrap_U8_U8_S16 },
327 { "add_saturate_U8_U8_S16", &add_saturate_U8_U8_S16 },
328 { "add_wrap_S16_S16_S16", &add_wrap_S16_S16_S16 },
329 { "add_saturate_S16_S16_S16", &add_saturate_S16_S16_S16 },
330 { "add_wrap_F32_F32_F32", &add_F32_F32_F32 },
331 { "add_saturate_F32_F32_F32", &add_F32_F32_F32 },
332 };
333
334 _input1 = input1;
335 _input2 = input2;
336 _output = output;
337
338 std::string function_to_call("add_");
339 function_to_call += policy == ConvertPolicy::WRAP ? "wrap_" : "saturate_";
340 function_to_call += string_from_data_type(input1->info()->data_type()) + "_";
341 function_to_call += string_from_data_type(input2->info()->data_type()) + "_";
342 function_to_call += string_from_data_type(output->info()->data_type());
343
344 auto it = map_function.find(function_to_call);
345
346 if(it != map_function.end())
347 {
348 _func = it->second;
349 }
350 else
351 {
352 ARM_COMPUTE_ERROR("You called arithmetic addition with the wrong tensor data type");
353 }
354
355 constexpr unsigned int num_elems_processed_per_iteration = 16;
356
357 // Configure kernel window
358 Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration));
359 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
360
361 update_window_and_padding(win,
362 AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration),
363 AccessWindowHorizontal(input2->info(), 0, num_elems_processed_per_iteration),
364 output_access);
365
366 ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(),
367 input2->info()->valid_region());
368
369 output_access.set_valid_region(win, valid_region);
370
371 INEKernel::configure(win);
372}
373
374void NEArithmeticAdditionKernel::run(const Window &window)
375{
376 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
377 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
378 ARM_COMPUTE_ERROR_ON(_func == nullptr);
379
380 (*_func)(_input1, _input2, _output, window);
381}