blob: d3e62b069e858c54316125d187919aa9fdb40aab [file] [log] [blame]
/*
 * Copyright (c) 2016, 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h"
25
26#include "arm_compute/core/Error.h"
27#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/ITensor.h"
29#include "arm_compute/core/TensorInfo.h"
30#include "arm_compute/core/Validate.h"
31
32#include <algorithm>
33#include <arm_neon.h>
34#include <cstdint>
35#include <map>
36#include <string>
37
// Forward-declare Coordinates, the index type received by the
// execute_window_loop lambdas below, then make the library namespace
// available unqualified throughout this translation unit.
namespace arm_compute
{
class Coordinates;
} // namespace arm_compute

using namespace arm_compute;
44
45namespace
46{
47void sub_wrap_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
48{
49 Iterator input1(in1, window);
50 Iterator input2(in2, window);
51 Iterator output(out, window);
52
53 execute_window_loop(window, [&](const Coordinates & id)
54 {
55 const uint8x16_t ta1 = vld1q_u8(input1.ptr());
56 const uint8x16_t ta2 = vld1q_u8(input2.ptr());
57
58 vst1q_u8(output.ptr(), vsubq_u8(ta1, ta2));
59 },
60 input1, input2, output);
61}
62
63void sub_saturate_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
64{
65 Iterator input1(in1, window);
66 Iterator input2(in2, window);
67 Iterator output(out, window);
68
69 execute_window_loop(window, [&](const Coordinates & id)
70 {
71 const uint8x16_t ta1 = vld1q_u8(input1.ptr());
72 const uint8x16_t ta2 = vld1q_u8(input2.ptr());
73
74 vst1q_u8(output.ptr(), vqsubq_u8(ta1, ta2));
75 },
76 input1, input2, output);
77}
78
79void sub_wrap_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
80{
81 Iterator input1(in1, window);
82 Iterator input2(in2, window);
83 Iterator output(out, window);
84
85 execute_window_loop(window, [&](const Coordinates & id)
86 {
87 const int16x8x2_t ta1 = vld2q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
88 const int16x8x2_t ta2 = vld2q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
89
90 const int16x8x2_t ta3 =
91 {
92 {
93 vsubq_s16(ta1.val[0], ta2.val[0]),
94 vsubq_s16(ta1.val[1], ta2.val[1])
95 }
96 };
97
98 vst2q_s16(reinterpret_cast<int16_t *>(output.ptr()), ta3);
99 },
100 input1, input2, output);
101}
102
103void sub_saturate_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
104{
105 Iterator input1(in1, window);
106 Iterator input2(in2, window);
107 Iterator output(out, window);
108
109 execute_window_loop(window, [&](const Coordinates & id)
110 {
111 const int16x8x2_t ta1 = vld2q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
112 const int16x8x2_t ta2 = vld2q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
113
114 const int16x8x2_t ta3 =
115 {
116 {
117 vqsubq_s16(ta1.val[0], ta2.val[0]),
118 vqsubq_s16(ta1.val[1], ta2.val[1])
119 }
120 };
121
122 vst2q_s16(reinterpret_cast<int16_t *>(output.ptr()), ta3);
123 },
124 input1, input2, output);
125}
126
127void sub_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
128{
129 Iterator input1(in1, window);
130 Iterator input2(in2, window);
131 Iterator output(out, window);
132
133 execute_window_loop(window, [&](const Coordinates & id)
134 {
135 const float32x4x4_t ta1 = vld4q_f32(reinterpret_cast<const float *>(input1.ptr()));
136 const float32x4x4_t ta2 = vld4q_f32(reinterpret_cast<const float *>(input2.ptr()));
137
138 const float32x4x4_t ta3 =
139 {
140 {
141 vsubq_f32(ta1.val[0], ta2.val[0]),
142 vsubq_f32(ta1.val[1], ta2.val[1]),
143 vsubq_f32(ta1.val[2], ta2.val[2]),
144 vsubq_f32(ta1.val[3], ta2.val[3]),
145 }
146 };
147
148 vst4q_f32(reinterpret_cast<float *>(output.ptr()), ta3);
149 },
150 input1, input2, output);
151}
152void sub_wrap_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
153{
154 Iterator input1(in1, window);
155 Iterator input2(in2, window);
156 Iterator output(out, window);
157
158 execute_window_loop(window, [&](const Coordinates & id)
159 {
160 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
161 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
162 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()) + 8);
163
164 a1_0 = vsubq_s16(a1_0, vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
165 a2_0 = vsubq_s16(a2_0, vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
166
167 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
168 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
169 },
170 input1, input2, output);
171}
172
173void sub_saturate_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
174{
175 Iterator input1(in1, window);
176 Iterator input2(in2, window);
177 Iterator output(out, window);
178
179 execute_window_loop(window, [&](const Coordinates & id)
180 {
181 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
182 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
183 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()) + 8);
184
185 a1_0 = vqsubq_s16(a1_0, vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
186 a2_0 = vqsubq_s16(a2_0, vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
187
188 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
189 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
190 },
191 input1, input2, output);
192}
193
194void sub_wrap_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
195{
196 Iterator input1(in1, window);
197 Iterator input2(in2, window);
198 Iterator output(out, window);
199
200 execute_window_loop(window, [&](const Coordinates & id)
201 {
202 const uint8x16_t bv_0 = vld1q_u8(input1.ptr());
203 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
204 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()) + 8);
205
206 a1_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))), a1_0);
207 a2_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))), a2_0);
208
209 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
210 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
211 },
212 input1, input2, output);
213}
214
215void sub_saturate_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
216{
217 Iterator input1(in1, window);
218 Iterator input2(in2, window);
219 Iterator output(out, window);
220
221 execute_window_loop(window, [&](const Coordinates & id)
222 {
223 const uint8x16_t bv_0 = vld1q_u8(input1.ptr());
224 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
225 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()) + 8);
226
227 a1_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))), a1_0);
228 a2_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))), a2_0);
229
230 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
231 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
232 },
233 input1, input2, output);
234}
235
236void sub_wrap_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
237{
238 Iterator input1(in1, window);
239 Iterator input2(in2, window);
240 Iterator output(out, window);
241
242 execute_window_loop(window, [&](const Coordinates & id)
243 {
244 const uint8x16_t av_0 = vld1q_u8(input1.ptr());
245 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
246
247 const int16x8_t a1_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(av_0))),
248 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
249 const int16x8_t a2_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(av_0))),
250 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
251
252 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
253 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
254 },
255 input1, input2, output);
256}
257
258void sub_saturate_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
259{
260 Iterator input1(in1, window);
261 Iterator input2(in2, window);
262 Iterator output(out, window);
263
264 execute_window_loop(window, [&](const Coordinates & id)
265 {
266 const uint8x16_t av_0 = vld1q_u8(input1.ptr());
267 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
268
269 const int16x8_t a1_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(av_0))),
270 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
271 const int16x8_t a2_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(av_0))),
272 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
273
274 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
275 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
276 },
277 input1, input2, output);
278}
279} // namespace
280
281NEArithmeticSubtractionKernel::NEArithmeticSubtractionKernel()
282 : _func(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr)
283{
284}
285
286void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy)
287{
288 ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
289
290 set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
291
292 if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
293 {
294 set_format_if_unknown(*output->info(), Format::S16);
295 }
296 else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
297 {
298 set_format_if_unknown(*output->info(), Format::F32);
299 }
300
301 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
302 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F32);
303 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F32);
304 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F32);
305 ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
306 "Output can only be U8 if both inputs are U8");
307
308 static std::map<std::string, SubFunction *> map_function =
309 {
310 { "sub_wrap_U8_U8_U8", &sub_wrap_U8_U8_U8 },
311 { "sub_wrap_U8_U8_S16", &sub_wrap_U8_U8_S16 },
312 { "sub_saturate_U8_U8_U8", &sub_saturate_U8_U8_U8 },
313 { "sub_saturate_U8_U8_S16", &sub_saturate_U8_U8_S16 },
314 { "sub_wrap_U8_S16_S16", &sub_wrap_U8_S16_S16 },
315 { "sub_wrap_S16_U8_S16", &sub_wrap_S16_U8_S16 },
316 { "sub_saturate_U8_S16_S16", &sub_saturate_U8_S16_S16 },
317 { "sub_saturate_S16_U8_S16", &sub_saturate_S16_U8_S16 },
318 { "sub_wrap_S16_S16_S16", &sub_wrap_S16_S16_S16 },
319 { "sub_saturate_S16_S16_S16", &sub_saturate_S16_S16_S16 },
320 { "sub_wrap_F32_F32_F32", &sub_F32_F32_F32 },
321 { "sub_saturate_F32_F32_F32", &sub_F32_F32_F32 },
322 };
323
324 _input1 = input1;
325 _input2 = input2;
326 _output = output;
327
328 std::string function_to_call("sub_");
329 function_to_call += policy == ConvertPolicy::WRAP ? "wrap_" : "saturate_";
330 function_to_call += string_from_data_type(input1->info()->data_type()) + "_";
331 function_to_call += string_from_data_type(input2->info()->data_type()) + "_";
332 function_to_call += string_from_data_type(output->info()->data_type());
333
334 auto it = map_function.find(function_to_call);
335
336 if(it != map_function.end())
337 {
338 _func = it->second;
339 }
340 else
341 {
342 ARM_COMPUTE_ERROR("You called subtract with the wrong image formats");
343 }
344
345 constexpr unsigned int num_elems_processed_per_iteration = 16;
346
347 // Configure kernel window
348 Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration));
349 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
350
351 update_window_and_padding(win,
352 AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration),
353 AccessWindowHorizontal(input2->info(), 0, num_elems_processed_per_iteration),
354 output_access);
355
356 ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(),
357 input2->info()->valid_region());
358
359 output_access.set_valid_region(win, valid_region);
360
361 INEKernel::configure(win);
362}
363
364void NEArithmeticSubtractionKernel::run(const Window &window)
365{
366 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
367 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
368 ARM_COMPUTE_ERROR_ON(_func == nullptr);
369
370 (*_func)(_input1, _input2, _output, window);
371}