blob: ff8fb84958bed23ab6b9291bc34b1c2fa9603e60 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +01002 * Copyright (c) 2016-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h"
25
Anthony Barbiereaefd002018-07-20 17:49:35 +010026#include "arm_compute/core/CPP/Validate.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/ITensor.h"
Michele Di Giorgio81f0d152017-07-11 15:00:52 +010030#include "arm_compute/core/NEON/NEFixedPoint.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010031#include "arm_compute/core/TensorInfo.h"
32#include "arm_compute/core/Validate.h"
33
34#include <algorithm>
35#include <arm_neon.h>
36#include <cstdint>
37#include <map>
38#include <string>
39
40using namespace arm_compute;
41
42namespace arm_compute
43{
44class Coordinates;
45} // namespace arm_compute
46
47namespace
48{
// Number of elements handled per step along dimension 0. It is used below as the window step and the
// horizontal access-window width in validate_and_configure_window(), and as the upper bound (minus one)
// of the border computed in border_size().
Georgios Pinitascbf39c62018-09-10 15:07:45 +010049constexpr unsigned int num_elems_processed_per_iteration = 16;
50
Anthony Barbier6ff3b192017-09-04 18:44:23 +010051void sub_wrap_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
52{
Georgios Pinitascbf39c62018-09-10 15:07:45 +010053 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
54 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +010055 Iterator output(out, window);
56
57 execute_window_loop(window, [&](const Coordinates & id)
58 {
59 const uint8x16_t ta1 = vld1q_u8(input1.ptr());
60 const uint8x16_t ta2 = vld1q_u8(input2.ptr());
61
62 vst1q_u8(output.ptr(), vsubq_u8(ta1, ta2));
63 },
64 input1, input2, output);
65}
66
67void sub_saturate_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
68{
Georgios Pinitascbf39c62018-09-10 15:07:45 +010069 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
70 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +010071 Iterator output(out, window);
72
73 execute_window_loop(window, [&](const Coordinates & id)
74 {
75 const uint8x16_t ta1 = vld1q_u8(input1.ptr());
76 const uint8x16_t ta2 = vld1q_u8(input2.ptr());
77
78 vst1q_u8(output.ptr(), vqsubq_u8(ta1, ta2));
79 },
80 input1, input2, output);
81}
82
83void sub_wrap_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
84{
Georgios Pinitascbf39c62018-09-10 15:07:45 +010085 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
86 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +010087 Iterator output(out, window);
88
89 execute_window_loop(window, [&](const Coordinates & id)
90 {
91 const int16x8x2_t ta1 = vld2q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
92 const int16x8x2_t ta2 = vld2q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
93
94 const int16x8x2_t ta3 =
95 {
96 {
97 vsubq_s16(ta1.val[0], ta2.val[0]),
98 vsubq_s16(ta1.val[1], ta2.val[1])
99 }
100 };
101
102 vst2q_s16(reinterpret_cast<int16_t *>(output.ptr()), ta3);
103 },
104 input1, input2, output);
105}
106
107void sub_saturate_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
108{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100109 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
110 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100111 Iterator output(out, window);
112
113 execute_window_loop(window, [&](const Coordinates & id)
114 {
115 const int16x8x2_t ta1 = vld2q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
116 const int16x8x2_t ta2 = vld2q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
117
118 const int16x8x2_t ta3 =
119 {
120 {
121 vqsubq_s16(ta1.val[0], ta2.val[0]),
122 vqsubq_s16(ta1.val[1], ta2.val[1])
123 }
124 };
125
126 vst2q_s16(reinterpret_cast<int16_t *>(output.ptr()), ta3);
127 },
128 input1, input2, output);
129}
130
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Subtract two pairs of 8-lane float16 vectors, pair-wise.
 *
 * @param[in] a Left-hand operand pair.
 * @param[in] b Right-hand operand pair.
 *
 * @return Pair holding a.val[0] - b.val[0] and a.val[1] - b.val[1].
 */
inline float16x8x2_t vsub2q_f16(const float16x8x2_t &a, const float16x8x2_t &b)
{
    float16x8x2_t diff;
    diff.val[0] = vsubq_f16(a.val[0], b.val[0]);
    diff.val[1] = vsubq_f16(a.val[1], b.val[1]);
    return diff;
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tellod7a5d222017-07-11 13:54:43 +0100145
146void sub_F16_F16_F16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
147{
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000148#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100149 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
150 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Pablo Tellod7a5d222017-07-11 13:54:43 +0100151 Iterator output(out, window);
152
153 execute_window_loop(window, [&](const Coordinates & id)
154 {
155 const float16x8x2_t a = vld2q_f16(reinterpret_cast<const float16_t *>(input1.ptr()));
156 const float16x8x2_t b = vld2q_f16(reinterpret_cast<const float16_t *>(input2.ptr()));
157
158 vst2q_f16(reinterpret_cast<float16_t *>(output.ptr()), vsub2q_f16(a, b));
159 },
160 input1, input2, output);
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000161#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tellod7a5d222017-07-11 13:54:43 +0100162 ARM_COMPUTE_UNUSED(in1);
163 ARM_COMPUTE_UNUSED(in2);
164 ARM_COMPUTE_UNUSED(out);
165 ARM_COMPUTE_UNUSED(window);
166 ARM_COMPUTE_ERROR("Not supported, recompile the library with arch=arm64-v8.2-a");
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000167#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tellod7a5d222017-07-11 13:54:43 +0100168}
169
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100170void sub_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
171{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100172 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
173 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100174 Iterator output(out, window);
175
176 execute_window_loop(window, [&](const Coordinates & id)
177 {
178 const float32x4x4_t ta1 = vld4q_f32(reinterpret_cast<const float *>(input1.ptr()));
179 const float32x4x4_t ta2 = vld4q_f32(reinterpret_cast<const float *>(input2.ptr()));
180
181 const float32x4x4_t ta3 =
182 {
183 {
184 vsubq_f32(ta1.val[0], ta2.val[0]),
185 vsubq_f32(ta1.val[1], ta2.val[1]),
186 vsubq_f32(ta1.val[2], ta2.val[2]),
187 vsubq_f32(ta1.val[3], ta2.val[3]),
188 }
189 };
190
191 vst4q_f32(reinterpret_cast<float *>(output.ptr()), ta3);
192 },
193 input1, input2, output);
194}
195void sub_wrap_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
196{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100197 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
198 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100199 Iterator output(out, window);
200
201 execute_window_loop(window, [&](const Coordinates & id)
202 {
203 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
204 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
205 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()) + 8);
206
207 a1_0 = vsubq_s16(a1_0, vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
208 a2_0 = vsubq_s16(a2_0, vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
209
210 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
211 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
212 },
213 input1, input2, output);
214}
215
216void sub_saturate_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
217{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100218 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
219 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100220 Iterator output(out, window);
221
222 execute_window_loop(window, [&](const Coordinates & id)
223 {
224 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
225 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
226 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()) + 8);
227
228 a1_0 = vqsubq_s16(a1_0, vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
229 a2_0 = vqsubq_s16(a2_0, vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
230
231 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
232 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
233 },
234 input1, input2, output);
235}
236
237void sub_wrap_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
238{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100239 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
240 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100241 Iterator output(out, window);
242
243 execute_window_loop(window, [&](const Coordinates & id)
244 {
245 const uint8x16_t bv_0 = vld1q_u8(input1.ptr());
246 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
247 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()) + 8);
248
249 a1_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))), a1_0);
250 a2_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))), a2_0);
251
252 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
253 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
254 },
255 input1, input2, output);
256}
257
258void sub_saturate_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
259{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100260 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
261 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100262 Iterator output(out, window);
263
264 execute_window_loop(window, [&](const Coordinates & id)
265 {
266 const uint8x16_t bv_0 = vld1q_u8(input1.ptr());
267 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
268 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()) + 8);
269
270 a1_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))), a1_0);
271 a2_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))), a2_0);
272
273 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
274 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
275 },
276 input1, input2, output);
277}
278
279void sub_wrap_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
280{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100281 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
282 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100283 Iterator output(out, window);
284
285 execute_window_loop(window, [&](const Coordinates & id)
286 {
287 const uint8x16_t av_0 = vld1q_u8(input1.ptr());
288 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
289
290 const int16x8_t a1_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(av_0))),
291 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
292 const int16x8_t a2_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(av_0))),
293 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
294
295 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
296 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
297 },
298 input1, input2, output);
299}
300
301void sub_saturate_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
302{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100303 Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
304 Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100305 Iterator output(out, window);
306
307 execute_window_loop(window, [&](const Coordinates & id)
308 {
309 const uint8x16_t av_0 = vld1q_u8(input1.ptr());
310 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
311
312 const int16x8_t a1_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(av_0))),
313 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
314 const int16x8_t a2_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(av_0))),
315 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
316
317 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
318 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
319 },
320 input1, input2, output);
321}
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000322
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100323inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output, ConvertPolicy policy)
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000324{
325 ARM_COMPUTE_UNUSED(policy);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100326 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
327 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
328 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
329 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000330
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100331 const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());
332 ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000333
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100334 // Validate in case of configured output
335 if(output.total_size() > 0)
336 {
337 ARM_COMPUTE_RETURN_ERROR_ON_MSG(
338 !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::U8)
339 && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
340 && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
341 && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
342 && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
343 && !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32 && output.data_type() == DataType::F32)
344 && !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16 && output.data_type() == DataType::F16),
345 "You called subtract with the wrong image formats");
346
347 ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
348 "Wrong shape for output");
349 }
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000350 return Status{};
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000351}
352
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100353inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo &input1, ITensorInfo &input2, ITensorInfo &output)
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000354{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100355 const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(input1, input2);
356 const TensorShape &out_shape = broadcast_pair.first;
357 const ValidRegion &valid_region = broadcast_pair.second;
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000358
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100359 // Auto initialize output if not initialized
360 {
361 set_shape_if_empty(output, out_shape);
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000362
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100363 if(input1.data_type() == DataType::S16 || input2.data_type() == DataType::S16)
364 {
365 set_format_if_unknown(output, Format::S16);
366 }
367 else if(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16)
368 {
369 set_format_if_unknown(output, Format::F16);
370 }
371 else if(input1.data_type() == DataType::F32 || input2.data_type() == DataType::F32)
372 {
373 set_format_if_unknown(output, Format::F32);
374 }
375 }
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000376
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100377 Window win = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration));
378 Window win_input1 = win.broadcast_if_dimension_le_one(input1);
379 Window win_input2 = win.broadcast_if_dimension_le_one(input2);
380
381 AccessWindowHorizontal input1_access(&input1, 0, num_elems_processed_per_iteration);
382 AccessWindowHorizontal input2_access(&input2, 0, num_elems_processed_per_iteration);
383 AccessWindowHorizontal output_access(&output, 0, num_elems_processed_per_iteration);
384
385 bool window_changed = update_window_and_padding(win_input1, input1_access)
386 || update_window_and_padding(win_input2, input2_access)
387 || update_window_and_padding(win, output_access);
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000388
389 output_access.set_valid_region(win, valid_region);
390
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000391 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000392 return std::make_pair(err, win);
393}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100394} // namespace
395
396NEArithmeticSubtractionKernel::NEArithmeticSubtractionKernel()
397 : _func(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr)
398{
399}
400
401void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy)
402{
403 ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100404 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info(), policy));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100405
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100406 // Configure kernel window
407 auto win_config = validate_and_configure_window(*input1->info(), *input2->info(), *output->info());
408 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100409
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000410 static std::map<std::string, NEArithmeticSubtractionKernel::SubFunction *> map_function =
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100411 {
412 { "sub_wrap_U8_U8_U8", &sub_wrap_U8_U8_U8 },
413 { "sub_wrap_U8_U8_S16", &sub_wrap_U8_U8_S16 },
414 { "sub_saturate_U8_U8_U8", &sub_saturate_U8_U8_U8 },
415 { "sub_saturate_U8_U8_S16", &sub_saturate_U8_U8_S16 },
416 { "sub_wrap_U8_S16_S16", &sub_wrap_U8_S16_S16 },
417 { "sub_wrap_S16_U8_S16", &sub_wrap_S16_U8_S16 },
418 { "sub_saturate_U8_S16_S16", &sub_saturate_U8_S16_S16 },
419 { "sub_saturate_S16_U8_S16", &sub_saturate_S16_U8_S16 },
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100420 { "sub_wrap_S16_S16_S16", &sub_wrap_S16_S16_S16 },
421 { "sub_saturate_S16_S16_S16", &sub_saturate_S16_S16_S16 },
422 { "sub_wrap_F32_F32_F32", &sub_F32_F32_F32 },
423 { "sub_saturate_F32_F32_F32", &sub_F32_F32_F32 },
Pablo Tellod7a5d222017-07-11 13:54:43 +0100424 { "sub_wrap_F16_F16_F16", &sub_F16_F16_F16 },
425 { "sub_saturate_F16_F16_F16", &sub_F16_F16_F16 },
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100426 };
427
428 _input1 = input1;
429 _input2 = input2;
430 _output = output;
431
432 std::string function_to_call("sub_");
433 function_to_call += policy == ConvertPolicy::WRAP ? "wrap_" : "saturate_";
434 function_to_call += string_from_data_type(input1->info()->data_type()) + "_";
435 function_to_call += string_from_data_type(input2->info()->data_type()) + "_";
436 function_to_call += string_from_data_type(output->info()->data_type());
437
438 auto it = map_function.find(function_to_call);
439
440 if(it != map_function.end())
441 {
442 _func = it->second;
443 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100444
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000445 INEKernel::configure(win_config.second);
446}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100447
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000448Status NEArithmeticSubtractionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
Ioan-Cristian Szabo397d58a2017-11-30 15:19:11 +0000449{
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100450 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
451
452 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output, policy));
453 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(*input1->clone(), *input2->clone(), *output->clone()).first);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100454
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000455 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100456}
457
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100458void NEArithmeticSubtractionKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100459{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100460 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100461 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
462 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
463 ARM_COMPUTE_ERROR_ON(_func == nullptr);
464
465 (*_func)(_input1, _input2, _output, window);
466}
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100467
468BorderSize NEArithmeticSubtractionKernel::border_size() const
469{
470 const unsigned int replicateSize = _output->info()->dimension(0) - std::min(_input1->info()->dimension(0), _input2->info()->dimension(0));
471 const unsigned int border = std::min<unsigned int>(num_elems_processed_per_iteration - 1U, replicateSize);
472 return BorderSize(0, border, 0, 0);
473}