blob: 8874b52e192802825660d17a1569c89db717c254 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Manuel Bottini6a2b6e82019-02-25 13:50:11 +00002 * Copyright (c) 2016-2019 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h"
25
Anthony Barbiereaefd002018-07-20 17:49:35 +010026#include "arm_compute/core/CPP/Validate.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/ITensor.h"
Manuel Bottini6a2b6e82019-02-25 13:50:11 +000030#include "arm_compute/core/NEON/NEAsymm.h"
Michele Di Giorgio81f0d152017-07-11 15:00:52 +010031#include "arm_compute/core/NEON/NEFixedPoint.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010032#include "arm_compute/core/TensorInfo.h"
33#include "arm_compute/core/Validate.h"
34
35#include <algorithm>
36#include <arm_neon.h>
37#include <cstdint>
38#include <map>
39#include <string>
40
namespace arm_compute
{
// Forward declaration of the type named by the (unused) lambda parameter of every
// execute_window_loop body in this file.
class Coordinates;
} // namespace arm_compute

using namespace arm_compute;
47
48namespace
49{
/** Elements handled by every kernel below per window step.
 *  Examples: one uint8x16_t, two int16x8_t, or four float32x4_t vectors.
 *  The same value drives the window steps, the access windows and border_size().
 */
constexpr unsigned int num_elems_processed_per_iteration{ 16 };
51
Anthony Barbier6ff3b192017-09-04 18:44:23 +010052void sub_wrap_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
53{
Georgios Pinitascbf39c62018-09-10 15:07:45 +010054 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
55 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +010056 Iterator output(out, window);
57
Michalis Spyroua4f378d2019-04-26 14:54:54 +010058 execute_window_loop(window, [&](const Coordinates &)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010059 {
60 const uint8x16_t ta1 = vld1q_u8(input1.ptr());
61 const uint8x16_t ta2 = vld1q_u8(input2.ptr());
62
63 vst1q_u8(output.ptr(), vsubq_u8(ta1, ta2));
64 },
65 input1, input2, output);
66}
67
68void sub_saturate_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
69{
Georgios Pinitascbf39c62018-09-10 15:07:45 +010070 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
71 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +010072 Iterator output(out, window);
73
Michalis Spyroua4f378d2019-04-26 14:54:54 +010074 execute_window_loop(window, [&](const Coordinates &)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010075 {
76 const uint8x16_t ta1 = vld1q_u8(input1.ptr());
77 const uint8x16_t ta2 = vld1q_u8(input2.ptr());
78
79 vst1q_u8(output.ptr(), vqsubq_u8(ta1, ta2));
80 },
81 input1, input2, output);
82}
83
Manuel Bottini6a2b6e82019-02-25 13:50:11 +000084void sub_saturate_QAYSMM8_QAYSMM8_QAYSMM8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
85{
86 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
87 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
88 Iterator output(out, window);
89
Georgios Pinitas4c5469b2019-05-21 13:32:43 +010090 const UniformQuantizationInfo iq1_info = in1->info()->quantization_info().uniform();
91 const UniformQuantizationInfo iq2_info = in2->info()->quantization_info().uniform();
92 const UniformQuantizationInfo oq_info = out->info()->quantization_info().uniform();
93
Michalis Spyroua4f378d2019-04-26 14:54:54 +010094 execute_window_loop(window, [&](const Coordinates &)
Manuel Bottini6a2b6e82019-02-25 13:50:11 +000095 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +010096 const float32x4x4_t ta1 = vdequantize(vld1q_u8(reinterpret_cast<const qasymm8_t *>(input1.ptr())), iq1_info);
97 const float32x4x4_t ta2 = vdequantize(vld1q_u8(reinterpret_cast<const qasymm8_t *>(input2.ptr())), iq2_info);
Manuel Bottini6a2b6e82019-02-25 13:50:11 +000098
99 const float32x4x4_t ta3 =
100 {
101 {
102 vsubq_f32(ta1.val[0], ta2.val[0]),
103 vsubq_f32(ta1.val[1], ta2.val[1]),
104 vsubq_f32(ta1.val[2], ta2.val[2]),
105 vsubq_f32(ta1.val[3], ta2.val[3]),
106 }
107 };
108
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100109 const uint8x16_t result = vquantize(ta3, oq_info);
Manuel Bottini6a2b6e82019-02-25 13:50:11 +0000110
111 vst1q_u8(reinterpret_cast<qasymm8_t *>(output.ptr()), result);
112 },
113 input1, input2, output);
114}
115
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100116void sub_wrap_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
117{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100118 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
119 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100120 Iterator output(out, window);
121
Michalis Spyroua4f378d2019-04-26 14:54:54 +0100122 execute_window_loop(window, [&](const Coordinates &)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100123 {
124 const int16x8x2_t ta1 = vld2q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
125 const int16x8x2_t ta2 = vld2q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
126
127 const int16x8x2_t ta3 =
128 {
129 {
130 vsubq_s16(ta1.val[0], ta2.val[0]),
131 vsubq_s16(ta1.val[1], ta2.val[1])
132 }
133 };
134
135 vst2q_s16(reinterpret_cast<int16_t *>(output.ptr()), ta3);
136 },
137 input1, input2, output);
138}
139
140void sub_saturate_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
141{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100142 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
143 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100144 Iterator output(out, window);
145
Michalis Spyroua4f378d2019-04-26 14:54:54 +0100146 execute_window_loop(window, [&](const Coordinates &)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100147 {
148 const int16x8x2_t ta1 = vld2q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
149 const int16x8x2_t ta2 = vld2q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
150
151 const int16x8x2_t ta3 =
152 {
153 {
154 vqsubq_s16(ta1.val[0], ta2.val[0]),
155 vqsubq_s16(ta1.val[1], ta2.val[1])
156 }
157 };
158
159 vst2q_s16(reinterpret_cast<int16_t *>(output.ptr()), ta3);
160 },
161 input1, input2, output);
162}
163
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Lane-wise subtraction of two float16x8x2_t values.
 *
 * @return { a.val[0] - b.val[0], a.val[1] - b.val[1] }
 */
inline float16x8x2_t vsub2q_f16(const float16x8x2_t &a, const float16x8x2_t &b)
{
    float16x8x2_t res;
    res.val[0] = vsubq_f16(a.val[0], b.val[0]);
    res.val[1] = vsubq_f16(a.val[1], b.val[1]);
    return res;
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tellod7a5d222017-07-11 13:54:43 +0100178
179void sub_F16_F16_F16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
180{
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000181#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100182 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
183 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Pablo Tellod7a5d222017-07-11 13:54:43 +0100184 Iterator output(out, window);
185
Michalis Spyroua4f378d2019-04-26 14:54:54 +0100186 execute_window_loop(window, [&](const Coordinates &)
Pablo Tellod7a5d222017-07-11 13:54:43 +0100187 {
188 const float16x8x2_t a = vld2q_f16(reinterpret_cast<const float16_t *>(input1.ptr()));
189 const float16x8x2_t b = vld2q_f16(reinterpret_cast<const float16_t *>(input2.ptr()));
190
191 vst2q_f16(reinterpret_cast<float16_t *>(output.ptr()), vsub2q_f16(a, b));
192 },
193 input1, input2, output);
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000194#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tellod7a5d222017-07-11 13:54:43 +0100195 ARM_COMPUTE_UNUSED(in1);
196 ARM_COMPUTE_UNUSED(in2);
197 ARM_COMPUTE_UNUSED(out);
198 ARM_COMPUTE_UNUSED(window);
199 ARM_COMPUTE_ERROR("Not supported, recompile the library with arch=arm64-v8.2-a");
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000200#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tellod7a5d222017-07-11 13:54:43 +0100201}
202
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100203void sub_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
204{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100205 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
206 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100207 Iterator output(out, window);
208
Michalis Spyroua4f378d2019-04-26 14:54:54 +0100209 execute_window_loop(window, [&](const Coordinates &)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100210 {
211 const float32x4x4_t ta1 = vld4q_f32(reinterpret_cast<const float *>(input1.ptr()));
212 const float32x4x4_t ta2 = vld4q_f32(reinterpret_cast<const float *>(input2.ptr()));
213
214 const float32x4x4_t ta3 =
215 {
216 {
217 vsubq_f32(ta1.val[0], ta2.val[0]),
218 vsubq_f32(ta1.val[1], ta2.val[1]),
219 vsubq_f32(ta1.val[2], ta2.val[2]),
220 vsubq_f32(ta1.val[3], ta2.val[3]),
221 }
222 };
223
224 vst4q_f32(reinterpret_cast<float *>(output.ptr()), ta3);
225 },
226 input1, input2, output);
227}
228void sub_wrap_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
229{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100230 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
231 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100232 Iterator output(out, window);
233
Michalis Spyroua4f378d2019-04-26 14:54:54 +0100234 execute_window_loop(window, [&](const Coordinates &)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100235 {
236 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
237 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
238 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()) + 8);
239
240 a1_0 = vsubq_s16(a1_0, vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
241 a2_0 = vsubq_s16(a2_0, vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
242
243 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
244 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
245 },
246 input1, input2, output);
247}
248
249void sub_saturate_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
250{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100251 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
252 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100253 Iterator output(out, window);
254
Michalis Spyroua4f378d2019-04-26 14:54:54 +0100255 execute_window_loop(window, [&](const Coordinates &)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100256 {
257 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
258 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
259 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()) + 8);
260
261 a1_0 = vqsubq_s16(a1_0, vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
262 a2_0 = vqsubq_s16(a2_0, vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
263
264 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
265 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
266 },
267 input1, input2, output);
268}
269
270void sub_wrap_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
271{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100272 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
273 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100274 Iterator output(out, window);
275
Michalis Spyroua4f378d2019-04-26 14:54:54 +0100276 execute_window_loop(window, [&](const Coordinates &)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100277 {
278 const uint8x16_t bv_0 = vld1q_u8(input1.ptr());
279 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
280 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()) + 8);
281
282 a1_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))), a1_0);
283 a2_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))), a2_0);
284
285 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
286 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
287 },
288 input1, input2, output);
289}
290
291void sub_saturate_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
292{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100293 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
294 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100295 Iterator output(out, window);
296
Michalis Spyroua4f378d2019-04-26 14:54:54 +0100297 execute_window_loop(window, [&](const Coordinates &)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100298 {
299 const uint8x16_t bv_0 = vld1q_u8(input1.ptr());
300 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
301 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()) + 8);
302
303 a1_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))), a1_0);
304 a2_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))), a2_0);
305
306 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
307 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
308 },
309 input1, input2, output);
310}
311
312void sub_wrap_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
313{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100314 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
315 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100316 Iterator output(out, window);
317
Michalis Spyroua4f378d2019-04-26 14:54:54 +0100318 execute_window_loop(window, [&](const Coordinates &)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100319 {
320 const uint8x16_t av_0 = vld1q_u8(input1.ptr());
321 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
322
323 const int16x8_t a1_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(av_0))),
324 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
325 const int16x8_t a2_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(av_0))),
326 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
327
328 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
329 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
330 },
331 input1, input2, output);
332}
333
334void sub_saturate_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
335{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100336 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
337 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100338 Iterator output(out, window);
339
Michalis Spyroua4f378d2019-04-26 14:54:54 +0100340 execute_window_loop(window, [&](const Coordinates &)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100341 {
342 const uint8x16_t av_0 = vld1q_u8(input1.ptr());
343 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
344
345 const int16x8_t a1_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(av_0))),
346 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
347 const int16x8_t a2_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(av_0))),
348 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
349
350 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
351 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
352 },
353 input1, input2, output);
354}
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000355
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100356inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output, ConvertPolicy policy)
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000357{
358 ARM_COMPUTE_UNUSED(policy);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100359 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
Manuel Bottini6a2b6e82019-02-25 13:50:11 +0000360 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32);
361 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32);
362 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32);
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000363
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100364 const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());
365 ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000366
Manuel Bottini6a2b6e82019-02-25 13:50:11 +0000367 ARM_COMPUTE_RETURN_ERROR_ON_MSG(
368 !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8)
369 && !(input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8)
370 && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8)
371 && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16)
372 && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8)
373 && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16)
374 && !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32)
375 && !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16),
376 "You called subtract with the wrong image formats");
377
378 ARM_COMPUTE_RETURN_ERROR_ON_MSG(
379 input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8 && policy == ConvertPolicy::WRAP,
380 "Convert policy cannot be WRAP if datatype is QASYMM8");
381
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100382 // Validate in case of configured output
383 if(output.total_size() > 0)
384 {
385 ARM_COMPUTE_RETURN_ERROR_ON_MSG(
386 !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::U8)
Manuel Bottini6a2b6e82019-02-25 13:50:11 +0000387 && !(input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8 && output.data_type() == DataType::QASYMM8)
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100388 && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
389 && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
390 && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
391 && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
392 && !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32 && output.data_type() == DataType::F32)
393 && !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16 && output.data_type() == DataType::F16),
394 "You called subtract with the wrong image formats");
395
396 ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
397 "Wrong shape for output");
398 }
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000399 return Status{};
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000400}
401
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100402inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo &input1, ITensorInfo &input2, ITensorInfo &output)
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000403{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100404 const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(input1, input2);
405 const TensorShape &out_shape = broadcast_pair.first;
406 const ValidRegion &valid_region = broadcast_pair.second;
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000407
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100408 // Auto initialize output if not initialized
409 {
410 set_shape_if_empty(output, out_shape);
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000411
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100412 if(input1.data_type() == DataType::S16 || input2.data_type() == DataType::S16)
413 {
414 set_format_if_unknown(output, Format::S16);
415 }
416 else if(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16)
417 {
418 set_format_if_unknown(output, Format::F16);
419 }
420 else if(input1.data_type() == DataType::F32 || input2.data_type() == DataType::F32)
421 {
422 set_format_if_unknown(output, Format::F32);
423 }
424 }
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000425
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100426 Window win = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration));
427 Window win_input1 = win.broadcast_if_dimension_le_one(input1);
428 Window win_input2 = win.broadcast_if_dimension_le_one(input2);
429
430 AccessWindowHorizontal input1_access(&input1, 0, num_elems_processed_per_iteration);
431 AccessWindowHorizontal input2_access(&input2, 0, num_elems_processed_per_iteration);
432 AccessWindowHorizontal output_access(&output, 0, num_elems_processed_per_iteration);
433
434 bool window_changed = update_window_and_padding(win_input1, input1_access)
435 || update_window_and_padding(win_input2, input2_access)
436 || update_window_and_padding(win, output_access);
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000437
438 output_access.set_valid_region(win, valid_region);
439
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000440 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000441 return std::make_pair(err, win);
442}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100443} // namespace
444
445NEArithmeticSubtractionKernel::NEArithmeticSubtractionKernel()
446 : _func(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr)
447{
448}
449
450void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy)
451{
452 ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100453 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info(), policy));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100454
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100455 // Configure kernel window
456 auto win_config = validate_and_configure_window(*input1->info(), *input2->info(), *output->info());
457 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100458
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000459 static std::map<std::string, NEArithmeticSubtractionKernel::SubFunction *> map_function =
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100460 {
461 { "sub_wrap_U8_U8_U8", &sub_wrap_U8_U8_U8 },
462 { "sub_wrap_U8_U8_S16", &sub_wrap_U8_U8_S16 },
463 { "sub_saturate_U8_U8_U8", &sub_saturate_U8_U8_U8 },
464 { "sub_saturate_U8_U8_S16", &sub_saturate_U8_U8_S16 },
Manuel Bottini6a2b6e82019-02-25 13:50:11 +0000465 { "sub_saturate_QASYMM8_QASYMM8_QASYMM8", &sub_saturate_QAYSMM8_QAYSMM8_QAYSMM8 },
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100466 { "sub_wrap_U8_S16_S16", &sub_wrap_U8_S16_S16 },
467 { "sub_wrap_S16_U8_S16", &sub_wrap_S16_U8_S16 },
468 { "sub_saturate_U8_S16_S16", &sub_saturate_U8_S16_S16 },
469 { "sub_saturate_S16_U8_S16", &sub_saturate_S16_U8_S16 },
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100470 { "sub_wrap_S16_S16_S16", &sub_wrap_S16_S16_S16 },
471 { "sub_saturate_S16_S16_S16", &sub_saturate_S16_S16_S16 },
472 { "sub_wrap_F32_F32_F32", &sub_F32_F32_F32 },
473 { "sub_saturate_F32_F32_F32", &sub_F32_F32_F32 },
Pablo Tellod7a5d222017-07-11 13:54:43 +0100474 { "sub_wrap_F16_F16_F16", &sub_F16_F16_F16 },
475 { "sub_saturate_F16_F16_F16", &sub_F16_F16_F16 },
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100476 };
477
478 _input1 = input1;
479 _input2 = input2;
480 _output = output;
481
482 std::string function_to_call("sub_");
483 function_to_call += policy == ConvertPolicy::WRAP ? "wrap_" : "saturate_";
484 function_to_call += string_from_data_type(input1->info()->data_type()) + "_";
485 function_to_call += string_from_data_type(input2->info()->data_type()) + "_";
486 function_to_call += string_from_data_type(output->info()->data_type());
487
488 auto it = map_function.find(function_to_call);
489
490 if(it != map_function.end())
491 {
492 _func = it->second;
493 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100494
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000495 INEKernel::configure(win_config.second);
496}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100497
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000498Status NEArithmeticSubtractionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000499{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100500 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
501
502 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output, policy));
503 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(*input1->clone(), *input2->clone(), *output->clone()).first);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100504
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000505 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100506}
507
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100508void NEArithmeticSubtractionKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100509{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100510 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100511 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
512 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
513 ARM_COMPUTE_ERROR_ON(_func == nullptr);
514
515 (*_func)(_input1, _input2, _output, window);
516}
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100517
518BorderSize NEArithmeticSubtractionKernel::border_size() const
519{
520 const unsigned int replicateSize = _output->info()->dimension(0) - std::min(_input1->info()->dimension(0), _input2->info()->dimension(0));
521 const unsigned int border = std::min<unsigned int>(num_elems_processed_per_iteration - 1U, replicateSize);
Michalis Spyroua4f378d2019-04-26 14:54:54 +0100522 return BorderSize{ 0, border, 0, 0 };
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100523}