/*
 * Copyright (c) 2016-2020 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h"

#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/NEON/NEAsymm.h"
#include "arm_compute/core/NEON/NESymm.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Validate.h"

namespace arm_compute
{
namespace
{
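// Scalar requantization helpers, dispatched on the element type via enable_if:
// int8_t selects quantize_qasymm8_signed() and uint8_t selects
// quantize_qasymm8(), so sub_quantized<T> can requantize its scalar tail
// without branching on T at runtime.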
template <typename T>
inline typename std::enable_if<std::is_same<T, int8_t>::value, int8_t>::type
quantize(float val, const QuantizationInfo &info)
{
    return quantize_qasymm8_signed(val, info);
}

template <typename T>
inline typename std::enable_if<std::is_same<T, uint8_t>::value, uint8_t>::type
quantize(float val, const QuantizationInfo &info)
{
    return quantize_qasymm8(val, info);
}

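// Element-wise subtraction for types whose input and output formats match
// (U8, S16, F16, F32). The x dimension is collapsed from the execution window
// and traversed manually: a main loop consumes 16 bytes per iteration with
// 128-bit NEON registers, and a scalar tail handles the left-over elements.
// Broadcasting along x (one input with a single element in that dimension) is
// handled by a dedicated branch that loads the broadcast value once per row.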
template <typename T>
void sub_same(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, bool is_sat)
{
    /** NEON vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    constexpr int window_step_x         = 16 / sizeof(T);
    const auto    window_start_x        = static_cast<int>(window.x().start());
    const auto    window_end_x          = static_cast<int>(window.x().end());
    const bool    is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    if(is_broadcast_across_x)
    {
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const T *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<T *>(output.ptr());

            const T    broadcast_value     = *reinterpret_cast<const T *>(broadcast_input.ptr());
            const auto broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});

            // Compute S elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);
                // Subtract in the right operand order up front: negating a saturated
                // difference is not equivalent to saturating the reversed difference
                // at the range boundaries (and overflows for unsigned types).
                const auto res = is_broadcast_input_2 ?
                                 (is_sat ? wrapper::vqsub(non_broadcast_v, broadcast_value_vec) : wrapper::vsub(non_broadcast_v, broadcast_value_vec)) :
                                 (is_sat ? wrapper::vqsub(broadcast_value_vec, non_broadcast_v) : wrapper::vsub(broadcast_value_vec, non_broadcast_v));
                wrapper::vstore(output_ptr + x, res);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const auto non_broadcast_v = *(non_broadcast_input_ptr + x);
                const auto res = is_broadcast_input_2 ?
                                 (is_sat ? wrapper::sub_sat(non_broadcast_v, broadcast_value) : static_cast<T>(non_broadcast_v - broadcast_value)) :
                                 (is_sat ? wrapper::sub_sat(broadcast_value, non_broadcast_v) : static_cast<T>(broadcast_value - non_broadcast_v));

                *(output_ptr + x) = res;
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const T *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const T *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<T *>(output.ptr());

            // Compute S elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto val1 = wrapper::vloadq(input1_ptr + x);
                const auto val2 = wrapper::vloadq(input2_ptr + x);
                const auto res  = is_sat ? wrapper::vqsub(val1, val2) : wrapper::vsub(val1, val2);
                wrapper::vstore(output_ptr + x, res);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const auto val1 = *(input1_ptr + x);
                const auto val2 = *(input2_ptr + x);

                *(output_ptr + x) = is_sat ? wrapper::sub_sat(val1, val2) : val1 - val2;
            }
        },
        input1, input2, output);
    }
}

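// Subtraction for the asymmetric quantized types (QASYMM8, QASYMM8_SIGNED).
// Both operands are dequantized to float, subtracted, and requantized:
//   out_q = round((s1 * (q1 - o1) - s2 * (q2 - o2)) / s_out + o_out)
// where s/o are the per-tensor scale and offset. Each 16-element register is
// widened into four float32x4_t lanes for the arithmetic, then narrowed back
// with saturation. Note that on AArch64 vcvtnq_s32_f32 rounds to nearest,
// while the AArch32 fallback vcvtq_s32_f32 truncates towards zero, so the two
// targets can differ slightly at rounding boundaries.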
template <typename T>
void sub_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, bool is_sat)
{
    ARM_COMPUTE_UNUSED(is_sat);

    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16;
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo iq1_info = in1->info()->quantization_info().uniform();
    const UniformQuantizationInfo iq2_info = in2->info()->quantization_info().uniform();
    const UniformQuantizationInfo oq_info  = out->info()->quantization_info().uniform();

    const float32x4_t vscale1    = vdupq_n_f32(iq1_info.scale);
    const float32x4_t vscale2    = vdupq_n_f32(iq2_info.scale);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / oq_info.scale);
    const int32x4_t   voffset1   = vdupq_n_s32(iq1_info.offset);
    const int32x4_t   voffset2   = vdupq_n_s32(iq2_info.offset);
    const float32x4_t voffseto   = vdupq_n_f32(oq_info.offset);

    if(is_broadcast_across_x)
    {
        const bool                    is_broadcast_input_2 = input2_win.x().step() == 0;
        Window                        broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window                        non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor                *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor                *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
        const UniformQuantizationInfo broadcast_qinfo      = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo  = non_broadcast_tensor->info()->quantization_info().uniform();

        // Dequantize each value with the parameters of the tensor it actually comes
        // from: vscale1/voffset1 always describe input1, which is not necessarily
        // the broadcast input in this branch.
        const float32x4_t vscale_b   = vdupq_n_f32(broadcast_qinfo.scale);
        const float32x4_t vscale_nb  = vdupq_n_f32(non_broadcast_qinfo.scale);
        const int32x4_t   voffset_b  = vdupq_n_s32(broadcast_qinfo.offset);
        const int32x4_t   voffset_nb = vdupq_n_s32(non_broadcast_qinfo.offset);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const T *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<T *>(output.ptr());

            const auto broadcast_value     = *reinterpret_cast<const T *>(broadcast_input.ptr());
            const auto broadcast_value_vec = wrapper::vdup_n(static_cast<T>(broadcast_value), wrapper::traits::vector_128_tag{});

            const float32x4x4_t bf =
            {
                {
                    vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgetlow(wrapper::vmovl(wrapper::vgetlow(broadcast_value_vec))))), voffset_b)), vscale_b),
                    vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgethigh(wrapper::vmovl(wrapper::vgetlow(broadcast_value_vec))))), voffset_b)), vscale_b),
                    vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgetlow(wrapper::vmovl(wrapper::vgethigh(broadcast_value_vec))))), voffset_b)), vscale_b),
                    vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgethigh(wrapper::vmovl(wrapper::vgethigh(broadcast_value_vec))))), voffset_b)), vscale_b),
                }
            };
            const float bfs = static_cast<int32_t>(broadcast_value - broadcast_qinfo.offset) * broadcast_qinfo.scale;

            // Compute S elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto a = wrapper::vloadq(non_broadcast_input_ptr + x);

                const float32x4x4_t af =
                {
                    {
                        vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgetlow(wrapper::vmovl(wrapper::vgetlow(a))))), voffset_nb)), vscale_nb),
                        vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgethigh(wrapper::vmovl(wrapper::vgetlow(a))))), voffset_nb)), vscale_nb),
                        vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgetlow(wrapper::vmovl(wrapper::vgethigh(a))))), voffset_nb)), vscale_nb),
                        vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgethigh(wrapper::vmovl(wrapper::vgethigh(a))))), voffset_nb)), vscale_nb),
                    }
                };

                // out = input1 - input2, so the subtraction order depends on which
                // input is the broadcast one.
                const int32x4x4_t rf =
                {
                    {
#ifdef __aarch64__
                        vcvtnq_s32_f32(vmlaq_f32(voffseto, is_broadcast_input_2 ? vsubq_f32(af.val[0], bf.val[0]) : vsubq_f32(bf.val[0], af.val[0]), invvscaleo)),
                        vcvtnq_s32_f32(vmlaq_f32(voffseto, is_broadcast_input_2 ? vsubq_f32(af.val[1], bf.val[1]) : vsubq_f32(bf.val[1], af.val[1]), invvscaleo)),
                        vcvtnq_s32_f32(vmlaq_f32(voffseto, is_broadcast_input_2 ? vsubq_f32(af.val[2], bf.val[2]) : vsubq_f32(bf.val[2], af.val[2]), invvscaleo)),
                        vcvtnq_s32_f32(vmlaq_f32(voffseto, is_broadcast_input_2 ? vsubq_f32(af.val[3], bf.val[3]) : vsubq_f32(bf.val[3], af.val[3]), invvscaleo)),
#else  //__aarch64__
                        vcvtq_s32_f32(vmlaq_f32(voffseto, is_broadcast_input_2 ? vsubq_f32(af.val[0], bf.val[0]) : vsubq_f32(bf.val[0], af.val[0]), invvscaleo)),
                        vcvtq_s32_f32(vmlaq_f32(voffseto, is_broadcast_input_2 ? vsubq_f32(af.val[1], bf.val[1]) : vsubq_f32(bf.val[1], af.val[1]), invvscaleo)),
                        vcvtq_s32_f32(vmlaq_f32(voffseto, is_broadcast_input_2 ? vsubq_f32(af.val[2], bf.val[2]) : vsubq_f32(bf.val[2], af.val[2]), invvscaleo)),
                        vcvtq_s32_f32(vmlaq_f32(voffseto, is_broadcast_input_2 ? vsubq_f32(af.val[3], bf.val[3]) : vsubq_f32(bf.val[3], af.val[3]), invvscaleo)),
#endif //__aarch64__
                    }
                };

                const auto pa = wrapper::vqmov<T>(vcombine_s16(vqmovn_s32(rf.val[0]), vqmovn_s32(rf.val[1])));
                const auto pb = wrapper::vqmov<T>(vcombine_s16(vqmovn_s32(rf.val[2]), vqmovn_s32(rf.val[3])));
                wrapper::vstore(output_ptr + x, wrapper::vcombine(pa, pb));
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const float afs = static_cast<int32_t>(*(non_broadcast_input_ptr + x) - non_broadcast_qinfo.offset) * non_broadcast_qinfo.scale;
                *(output_ptr + x) = quantize<T>(is_broadcast_input_2 ? (afs - bfs) : (bfs - afs), out->info()->quantization_info());
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const T *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const T *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<T *>(output.ptr());

            // Compute S elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto a = wrapper::vloadq(input1_ptr + x);
                const auto b = wrapper::vloadq(input2_ptr + x);

                const float32x4x4_t af =
                {
                    {
                        vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgetlow(wrapper::vmovl(wrapper::vgetlow(a))))), voffset1)), vscale1),
                        vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgethigh(wrapper::vmovl(wrapper::vgetlow(a))))), voffset1)), vscale1),
                        vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgetlow(wrapper::vmovl(wrapper::vgethigh(a))))), voffset1)), vscale1),
                        vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgethigh(wrapper::vmovl(wrapper::vgethigh(a))))), voffset1)), vscale1),
                    }
                };

                const float32x4x4_t bf =
                {
                    {
                        vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgetlow(wrapper::vmovl(wrapper::vgetlow(b))))), voffset2)), vscale2),
                        vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgethigh(wrapper::vmovl(wrapper::vgetlow(b))))), voffset2)), vscale2),
                        vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgetlow(wrapper::vmovl(wrapper::vgethigh(b))))), voffset2)), vscale2),
                        vmulq_f32(vcvtq_f32_s32(vsubq_s32(wrapper::vreinterpret(wrapper::vmovl(wrapper::vgethigh(wrapper::vmovl(wrapper::vgethigh(b))))), voffset2)), vscale2),
                    }
                };

                const int32x4x4_t rf =
                {
                    {
#ifdef __aarch64__
                        vcvtnq_s32_f32(vmlaq_f32(voffseto, vsubq_f32(af.val[0], bf.val[0]), invvscaleo)),
                        vcvtnq_s32_f32(vmlaq_f32(voffseto, vsubq_f32(af.val[1], bf.val[1]), invvscaleo)),
                        vcvtnq_s32_f32(vmlaq_f32(voffseto, vsubq_f32(af.val[2], bf.val[2]), invvscaleo)),
                        vcvtnq_s32_f32(vmlaq_f32(voffseto, vsubq_f32(af.val[3], bf.val[3]), invvscaleo)),
#else  //__aarch64__
                        vcvtq_s32_f32(vmlaq_f32(voffseto, vsubq_f32(af.val[0], bf.val[0]), invvscaleo)),
                        vcvtq_s32_f32(vmlaq_f32(voffseto, vsubq_f32(af.val[1], bf.val[1]), invvscaleo)),
                        vcvtq_s32_f32(vmlaq_f32(voffseto, vsubq_f32(af.val[2], bf.val[2]), invvscaleo)),
                        vcvtq_s32_f32(vmlaq_f32(voffseto, vsubq_f32(af.val[3], bf.val[3]), invvscaleo)),
#endif //__aarch64__
                    }
                };

                const auto pa = wrapper::vqmov<T>(vcombine_s16(vqmovn_s32(rf.val[0]), vqmovn_s32(rf.val[1])));
                const auto pb = wrapper::vqmov<T>(vcombine_s16(vqmovn_s32(rf.val[2]), vqmovn_s32(rf.val[3])));
                wrapper::vstore(output_ptr + x, wrapper::vcombine(pa, pb));
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const float afs = static_cast<int32_t>((*(input1_ptr + x)) - iq1_info.offset) * iq1_info.scale;
                const float bfs = static_cast<int32_t>((*(input2_ptr + x)) - iq2_info.offset) * iq2_info.scale;

                *(output_ptr + x) = quantize<T>((afs - bfs), out->info()->quantization_info());
            }
        },
        input1, input2, output);
    }
}

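// Subtraction for QSYMM16: symmetric quantization has a zero offset, so only
// the scales take part in the conversion:
//   out_q = round((s1 * q1 - s2 * q2) / s_out)
// The result is saturated to int16 when narrowing back with vqmovn_s32.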
void sub_QSYMM16_QSYMM16_QSYMM16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, bool is_sat)
{
    ARM_COMPUTE_UNUSED(is_sat);

    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 8;
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo iq1_info = in1->info()->quantization_info().uniform();
    const UniformQuantizationInfo iq2_info = in2->info()->quantization_info().uniform();
    const UniformQuantizationInfo oq_info  = out->info()->quantization_info().uniform();

    const float32x4_t vscale1    = vdupq_n_f32(iq1_info.scale);
    const float32x4_t vscale2    = vdupq_n_f32(iq2_info.scale);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / oq_info.scale);

    if(is_broadcast_across_x)
    {
        const bool                    is_broadcast_input_2 = input2_win.x().step() == 0;
        Window                        broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window                        non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor                *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor                *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
        const UniformQuantizationInfo broadcast_qinfo      = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo  = non_broadcast_tensor->info()->quantization_info().uniform();

        // Use the scale of the tensor each value actually comes from: vscale1 and
        // vscale2 always describe input1 and input2, but either input may be the
        // broadcast one in this branch.
        const float32x4_t vscale_b  = vdupq_n_f32(broadcast_qinfo.scale);
        const float32x4_t vscale_nb = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const int16_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<int16_t *>(output.ptr());

            const int16_t   broadcast_value     = *reinterpret_cast<const int16_t *>(broadcast_input.ptr());
            const int16x8_t broadcast_value_vec = vdupq_n_s16(broadcast_value);

            const float32x4x2_t bf =
            {
                {
                    vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(broadcast_value_vec))), vscale_b),
                    vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(broadcast_value_vec))), vscale_b),
                }
            };
            const float bfs = static_cast<int32_t>(broadcast_value) * broadcast_qinfo.scale;

            // Compute S elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const int16x8_t     a  = vld1q_s16(non_broadcast_input_ptr + x);
                const float32x4x2_t af =
                {
                    {
                        vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(a))), vscale_nb),
                        vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(a))), vscale_nb),
                    }
                };

                // out = input1 - input2, so the subtraction order depends on which
                // input is the broadcast one.
                const int32x4x2_t rf =
                {
                    {
#ifdef __aarch64__
                        vcvtnq_s32_f32(vmulq_f32(is_broadcast_input_2 ? vsubq_f32(af.val[0], bf.val[0]) : vsubq_f32(bf.val[0], af.val[0]), invvscaleo)),
                        vcvtnq_s32_f32(vmulq_f32(is_broadcast_input_2 ? vsubq_f32(af.val[1], bf.val[1]) : vsubq_f32(bf.val[1], af.val[1]), invvscaleo)),
#else  //__aarch64__
                        vcvtq_s32_f32(vmulq_f32(is_broadcast_input_2 ? vsubq_f32(af.val[0], bf.val[0]) : vsubq_f32(bf.val[0], af.val[0]), invvscaleo)),
                        vcvtq_s32_f32(vmulq_f32(is_broadcast_input_2 ? vsubq_f32(af.val[1], bf.val[1]) : vsubq_f32(bf.val[1], af.val[1]), invvscaleo)),
#endif //__aarch64__
                    }
                };

                const int16x8_t pa = vcombine_s16(vqmovn_s32(rf.val[0]), vqmovn_s32(rf.val[1]));
                vst1q_s16(output_ptr + x, pa);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const float afs = static_cast<int32_t>(*(non_broadcast_input_ptr + x)) * non_broadcast_qinfo.scale;
                *(output_ptr + x) = quantize_qsymm16(is_broadcast_input_2 ? (afs - bfs) : (bfs - afs), oq_info);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const int16_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

            // Compute S elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const int16x8_t a = vld1q_s16(input1_ptr + x);
                const int16x8_t b = vld1q_s16(input2_ptr + x);

                const float32x4x2_t af =
                {
                    {
                        vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(a))), vscale1),
                        vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(a))), vscale1),
                    }
                };

                const float32x4x2_t bf =
                {
                    {
                        vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(b))), vscale2),
                        vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(b))), vscale2),
                    }
                };

                const int32x4x2_t rf =
                {
                    {
#ifdef __aarch64__
                        vcvtnq_s32_f32(vmulq_f32(vsubq_f32(af.val[0], bf.val[0]), invvscaleo)),
                        vcvtnq_s32_f32(vmulq_f32(vsubq_f32(af.val[1], bf.val[1]), invvscaleo)),
#else  //__aarch64__
                        vcvtq_s32_f32(vmulq_f32(vsubq_f32(af.val[0], bf.val[0]), invvscaleo)),
                        vcvtq_s32_f32(vmulq_f32(vsubq_f32(af.val[1], bf.val[1]), invvscaleo)),
#endif //__aarch64__
                    }
                };

                const int16x8_t pa = vcombine_s16(vqmovn_s32(rf.val[0]), vqmovn_s32(rf.val[1]));
                vst1q_s16(output_ptr + x, pa);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const float afs = static_cast<int32_t>((*(input1_ptr + x))) * iq1_info.scale;
                const float bfs = static_cast<int32_t>((*(input2_ptr + x))) * iq2_info.scale;
                *(output_ptr + x) = quantize_qsymm16((afs - bfs), out->info()->quantization_info());
            }
        },
        input1, input2, output);
    }
}

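// Mixed-type subtraction with one S16 and one U8 operand. The U8 values are
// widened to S16 with vmovl before subtracting; is_swapped reverses the
// operand order so this single implementation serves both S16 - U8 and
// U8 - S16 through the two thin wrappers below.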
void sub_S16_U8_S16_impl(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, bool is_sat, bool is_swapped)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 8;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        if(!is_sat)
        {
            // Compute S elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto vin1 = wrapper::vloadq(input1_ptr + x);
                const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x)));
                const auto res  = is_swapped ? wrapper::vsub(vin2, vin1) : wrapper::vsub(vin1, vin2);
                wrapper::vstore(output_ptr + x, res);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const auto res = is_swapped ? static_cast<int16_t>(*(input2_ptr + x)) - *(input1_ptr + x) : *(input1_ptr + x) - static_cast<int16_t>(*(input2_ptr + x));
                *(output_ptr + x) = res;
            }
        }
        else
        {
            // Compute S elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto vin1 = wrapper::vloadq(input1_ptr + x);
                const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x)));
                const auto res  = is_swapped ? wrapper::vqsub(vin2, vin1) : wrapper::vqsub(vin1, vin2);
                wrapper::vstore(output_ptr + x, res);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const auto res = is_swapped ? wrapper::sub_sat(static_cast<int16_t>(*(input2_ptr + x)), *(input1_ptr + x)) : wrapper::sub_sat(*(input1_ptr + x), static_cast<int16_t>(*(input2_ptr + x)));
                *(output_ptr + x) = res;
            }
        }
    },
    input1, input2, output);
}

void sub_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, bool is_sat)
{
    sub_S16_U8_S16_impl(in1, in2, out, window, is_sat, false);
}

void sub_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, bool is_sat)
{
    // Swap arguments
    sub_S16_U8_S16_impl(in2, in1, out, window, is_sat, true);
}

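// U8 - U8 with an S16 result: both operands are widened to S16 first, so the
// true difference always fits in the output range and the wrap and saturate
// paths produce identical results; both are kept for symmetry with the other
// kernels.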
void sub_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, bool is_sat)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 8;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        if(!is_sat)
        {
            // Compute S elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto vin1 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input1_ptr + x)));
                const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x)));
                wrapper::vstore(output_ptr + x, wrapper::vsub(vin1, vin2));
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                *(output_ptr + x) = static_cast<int16_t>(*(input1_ptr + x)) - static_cast<int16_t>(*(input2_ptr + x));
            }
        }
        else
        {
            // Compute S elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto vin1 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input1_ptr + x)));
                const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x)));
                wrapper::vstore(output_ptr + x, wrapper::vqsub(vin1, vin2));
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                *(output_ptr + x) = wrapper::sub_sat(static_cast<int16_t>(*(input1_ptr + x)),
                                                     static_cast<int16_t>(*(input2_ptr + x)));
            }
        }
    },
    input1, input2, output);
}

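// Checks that the input/output data type combination is one the kernel
// implements, that the input shapes are broadcast compatible with each other
// (and with the output shape, if it is already configured), and that WRAP is
// not requested for quantized types, where wrap-around arithmetic is not
// meaningful after requantization.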
inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output, ConvertPolicy policy)
{
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::F16, DataType::F32);

    const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
        !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8)
        && !(input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8)
        && !(input1.data_type() == DataType::QASYMM8_SIGNED && input2.data_type() == DataType::QASYMM8_SIGNED)
        && !(input1.data_type() == DataType::QSYMM16 && input2.data_type() == DataType::QSYMM16)
        && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16)
        && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8)
        && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16)
        && !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32)
        && !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16),
        "You called subtract with the wrong image formats");

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
        ((input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8)
         || (input1.data_type() == DataType::QASYMM8_SIGNED && input2.data_type() == DataType::QASYMM8_SIGNED)
         || (input1.data_type() == DataType::QSYMM16 && input2.data_type() == DataType::QSYMM16))
        && policy == ConvertPolicy::WRAP,
        "Convert policy cannot be WRAP if datatype is QASYMM8, QASYMM8_SIGNED or QSYMM16");

    // Validate in case of configured output
    if(output.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(
            !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::U8)
            && !(input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8 && output.data_type() == DataType::QASYMM8)
            && !(input1.data_type() == DataType::QASYMM8_SIGNED && input2.data_type() == DataType::QASYMM8_SIGNED && output.data_type() == DataType::QASYMM8_SIGNED)
            && !(input1.data_type() == DataType::QSYMM16 && input2.data_type() == DataType::QSYMM16 && output.data_type() == DataType::QSYMM16)
            && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
            && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
            && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
            && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
            && !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32 && output.data_type() == DataType::F32)
            && !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16 && output.data_type() == DataType::F16),
            "You called subtract with the wrong image formats");

        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
                                        "Wrong shape for output");
    }
    return Status{};
}
} // namespace

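// Usage sketch (illustrative only, not part of this file's contract): callers
// typically reach this kernel through the NEArithmeticSubtraction runtime
// function rather than configuring it directly, along the lines of
//
//   Tensor in1, in2, out;                 // allocated elsewhere, e.g. as F32
//   NEArithmeticSubtraction sub;
//   sub.configure(&in1, &in2, &out, ConvertPolicy::SATURATE);
//   sub.run();                            // out = in1 - in2, saturating
//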
NEArithmeticSubtractionKernel::NEArithmeticSubtractionKernel()
    : _func(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _policy(ConvertPolicy::WRAP)
{
}

void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info(), policy));

    _input1 = input1;
    _input2 = input2;
    _output = output;
    _policy = policy;

    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1->info(), *input2->info());
    const TensorShape &out_shape    = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    set_shape_if_empty(*output->info(), out_shape);

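    // Select the implementation from the input/output data types; the chosen
    // function pointer is invoked once per execution window in run().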
    switch(input1->info()->data_type())
    {
        case DataType::U8:
            if(input2->info()->data_type() == DataType::U8 && output->info()->data_type() == DataType::U8)
            {
                _func = &sub_same<uint8_t>;
            }
            else if(input2->info()->data_type() == DataType::U8 && output->info()->data_type() == DataType::S16)
            {
                _func = &sub_U8_U8_S16;
            }
            else
            {
                _func = &sub_U8_S16_S16;
            }
            break;
        case DataType::QASYMM8:
            _func = &sub_quantized<uint8_t>;
            set_data_type_if_unknown(*output->info(), DataType::QASYMM8);
            break;
        case DataType::QASYMM8_SIGNED:
            _func = &sub_quantized<int8_t>;
            set_data_type_if_unknown(*output->info(), DataType::QASYMM8_SIGNED);
            break;
        case DataType::S16:
            if(input2->info()->data_type() == DataType::U8)
            {
                _func = &sub_S16_U8_S16;
            }
            else
            {
                _func = &sub_same<int16_t>;
            }
            set_format_if_unknown(*output->info(), Format::S16);
            break;
        case DataType::QSYMM16:
            _func = &sub_QSYMM16_QSYMM16_QSYMM16;
            set_data_type_if_unknown(*output->info(), DataType::QSYMM16);
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &sub_same<float16_t>;
            set_format_if_unknown(*output->info(), Format::F16);
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &sub_same<float>;
            set_format_if_unknown(*output->info(), Format::F32);
            break;
        default:
            _func = nullptr;
    }

    // NEArithmeticSubtractionKernel doesn't need padding so update_window_and_padding() can be skipped
    output->info()->set_valid_region(valid_region);
    Window win = calculate_max_window(valid_region, Steps());

    INEKernel::configure(win);
}

Status NEArithmeticSubtractionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output, policy));

    return Status{};
}

void NEArithmeticSubtractionKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_func == nullptr);

    (*_func)(_input1, _input2, _output, window, (_policy == ConvertPolicy::SATURATE));
}
} // namespace arm_compute