blob: 7c9514723b2464f42ec39b1a8d8b6493e3b603c5 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2016, 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"
25
26#include "arm_compute/core/Error.h"
27#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/IAccessWindow.h"
29#include "arm_compute/core/ITensor.h"
30#include "arm_compute/core/NEON/NEFixedPoint.h"
31#include "arm_compute/core/TensorInfo.h"
32#include "arm_compute/core/Validate.h"
33#include "arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h"
34
35#include <arm_neon.h>
36#include <climits>
37#include <cmath>
38#include <cstdint>
39#include <cstdlib>
40
// Forward declaration of Coordinates, the parameter type of the execute_window_loop callbacks in run().
namespace arm_compute
{
class Coordinates;
} // namespace arm_compute

// The kernel implementation below refers to library types without the arm_compute:: qualifier.
using namespace arm_compute;
47
48namespace
49{
50const float scale255_constant = 1.f / 255.f;
51const float32x4_t scale255_constant_f32q = vdupq_n_f32(scale255_constant);
52const float32x4_t positive_round_f32q = vdupq_n_f32(0.5f);
53
54/* Scales a given vector by 1/255.
55 *
56 * @note This does not work for all cases. e.g. for float of 0.49999999999999994 and large floats.
57 *
58 * @param in Input vector to scale.
59 * @return Scaled output rounded to nearest (round half up).
60 */
61inline int32x4_t scale255_S32_S32(int32x4_t in)
62{
63 // Scale
64 const float32x4_t tmp = vmulq_f32(vcvtq_f32_s32(in), scale255_constant_f32q);
65 // Round to nearest (round half up)
66 // Add +0.5 for all values
67 // Afterwards vcvt rounds toward zero
68 return vcvtq_s32_f32(vaddq_f32(tmp, positive_round_f32q));
69}
70
71inline uint16x8_t scale255_U16_U16(uint16x8_t in)
72{
73 const int32x4_t tmp_s1 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(in))));
74 const int32x4_t tmp_s2 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(in))));
75 return vreinterpretq_u16_s16(vcombine_s16(vmovn_s32(tmp_s2), vmovn_s32(tmp_s1)));
76}
77
78template <bool is_scale255, bool is_sat>
79void mul_U8_U8_U8_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
80{
81 const auto input1 = static_cast<const uint8_t *__restrict>(input1_ptr);
82 const auto input2 = static_cast<const uint8_t *__restrict>(input2_ptr);
83 const auto output = static_cast<uint8_t *__restrict>(output_ptr);
84
85 const uint8x16_t ta1 = vld1q_u8(input1);
86 const uint8x16_t ta2 = vld1q_u8(input2);
87
88 uint16x8_t tmp1_high = vmovl_u8(vget_high_u8(ta1));
89 const uint16x8_t tmp2_high = vmovl_u8(vget_high_u8(ta2));
90 uint16x8_t tmp1_low = vmovl_u8(vget_low_u8(ta1));
91 const uint16x8_t tmp2_low = vmovl_u8(vget_low_u8(ta2));
92
93 tmp1_high = vmulq_u16(tmp1_high, tmp2_high);
94 tmp1_low = vmulq_u16(tmp1_low, tmp2_low);
95
96 if(is_scale255)
97 {
98 tmp1_high = scale255_U16_U16(tmp1_high);
99 tmp1_low = scale255_U16_U16(tmp1_low);
100 }
101 else
102 {
103 const int16x8_t vn = vdupq_n_s16(-n);
104
105 if(is_sat)
106 {
107 tmp1_high = vqshlq_u16(tmp1_high, vn);
108 tmp1_low = vqshlq_u16(tmp1_low, vn);
109 }
110 else
111 {
112 tmp1_high = vshlq_u16(tmp1_high, vn);
113 tmp1_low = vshlq_u16(tmp1_low, vn);
114 }
115 }
116
117 if(is_sat)
118 {
119 vst1q_u8(output, vcombine_u8(vqmovn_u16(tmp1_low), vqmovn_u16(tmp1_high)));
120 }
121 else
122 {
123 vst1q_u8(output, vcombine_u8(vmovn_u16(tmp1_low), vmovn_u16(tmp1_high)));
124 }
125}
126
127template <bool is_scale255, bool is_sat>
128void mul_QS8_QS8_QS8_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n, int fixed_point_position)
129{
130 // n is the exponent of the scaling factor, that is scale = 1/2^n. Currently, we only support scaling factor equal to 1 => n = 0.
131 ARM_COMPUTE_ERROR_ON_MSG(n != 0, "Scaling factor different than 1 not supported for 8-bit fixed-point pixel-wise multiplication");
132 ARM_COMPUTE_UNUSED(n);
133
134 const auto input1 = static_cast<const qint8_t *__restrict>(input1_ptr);
135 const auto input2 = static_cast<const qint8_t *__restrict>(input2_ptr);
136 const auto output = static_cast<qint8_t *__restrict>(output_ptr);
137
138 const qint8x16_t ta1 = vld1q_qs8(input1);
139 const qint8x16_t ta2 = vld1q_qs8(input2);
140
141 qint8x16_t res = (is_sat) ? vqmulq_qs8(ta1, ta2, fixed_point_position) : vmulq_qs8(ta1, ta2, fixed_point_position);
142
143 vst1q_s8(output, res);
144}
145
146template <bool is_scale255, bool is_sat>
147inline int16x8_t mul_S16_S16_S16_n_loop(const int16x8_t &input1, const int16x8_t &input2, int n)
148{
149 int32x4_t tmp1_high = vmovl_s16(vget_high_s16(input1));
150 const int32x4_t tmp2_high = vmovl_s16(vget_high_s16(input2));
151 int32x4_t tmp1_low = vmovl_s16(vget_low_s16(input1));
152 const int32x4_t tmp2_low = vmovl_s16(vget_low_s16(input2));
153
154 tmp1_high = vmulq_s32(tmp1_high, tmp2_high);
155 tmp1_low = vmulq_s32(tmp1_low, tmp2_low);
156
157 if(is_scale255)
158 {
159 tmp1_high = scale255_S32_S32(tmp1_high);
160 tmp1_low = scale255_S32_S32(tmp1_low);
161 }
162 else
163 {
164 // Right shift amount
165 const int32x4_t vn = vdupq_n_s32(-n);
166 // Left shift amount
167 const int32x4_t vnl = vdupq_n_s32(n);
168 // Calculate conversion bit
169 const uint32x4_t tmp1_high_u = vreinterpretq_u32_s32(tmp1_high);
170 const uint32x4_t tmp1_low_u = vreinterpretq_u32_s32(tmp1_low);
171 const uint32x4_t sign_high = vshrq_n_u32(tmp1_high_u, 31);
172 const uint32x4_t sign_low = vshrq_n_u32(tmp1_low_u, 31);
173 const int32x4_t sign_high_s = vreinterpretq_s32_u32(sign_high);
174 const int32x4_t sign_low_s = vreinterpretq_s32_u32(sign_low);
175 const int32x4_t convert_high = vsubq_s32(vshlq_s32(sign_high_s, vnl), sign_high_s);
176 const int32x4_t convert_low = vsubq_s32(vshlq_s32(sign_low_s, vnl), sign_low_s);
177 if(is_sat)
178 {
179 tmp1_high = vqshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
180 tmp1_low = vqshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
181 }
182 else
183 {
184 tmp1_high = vshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
185 tmp1_low = vshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
186 }
187 }
188
189 if(is_sat)
190 {
191 return vcombine_s16(vqmovn_s32(tmp1_low), vqmovn_s32(tmp1_high));
192 }
193 else
194 {
195 return vcombine_s16(vmovn_s32(tmp1_low), vmovn_s32(tmp1_high));
196 }
197}
198
199template <bool is_scale255, bool is_sat>
200inline int16x8x2_t mul_S16_S16_S16_n_k(const int16x8x2_t &input1, const int16x8x2_t &input2, int n)
201{
202 const int16x8x2_t result =
203 {
204 {
205 // First 8 elements
206 mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[0], input2.val[0], n),
207 // Second 8 elements
208 mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[1], input2.val[1], n)
209 }
210 };
211
212 return result;
213}
214
215template <bool is_scale255, bool is_sat>
216void mul_S16_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
217{
218 const auto input1 = static_cast<const int16_t *__restrict>(input1_ptr);
219 const auto input2 = static_cast<const int16_t *__restrict>(input2_ptr);
220 const auto output = static_cast<int16_t *__restrict>(output_ptr);
221
222 const int16x8x2_t ta1 = vld2q_s16(input1);
223 const int16x8x2_t ta2 = vld2q_s16(input2);
224 const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);
225
226 vst2q_s16(output, result);
227}
228
229template <bool is_scale255, bool is_sat>
230void mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale)
231{
232 const auto input1 = static_cast<const float *__restrict>(input1_ptr);
233 const auto input2 = static_cast<const float *__restrict>(input2_ptr);
234 const auto output = static_cast<float *__restrict>(output_ptr);
235
236 const float32x4x4_t ta1 = vld4q_f32(input1);
237 const float32x4x4_t ta2 = vld4q_f32(input2);
238 const float32x4_t scale_vec = vdupq_n_f32(scale);
239 const float32x4x4_t result =
240 {
241 {
242 vmulq_f32(vmulq_f32(ta1.val[0], ta2.val[0]), scale_vec),
243 vmulq_f32(vmulq_f32(ta1.val[1], ta2.val[1]), scale_vec),
244 vmulq_f32(vmulq_f32(ta1.val[2], ta2.val[2]), scale_vec),
245 vmulq_f32(vmulq_f32(ta1.val[3], ta2.val[3]), scale_vec)
246 }
247 };
248 vst4q_f32(output, result);
249}
250
251template <bool is_scale255, bool is_sat>
252void mul_U8_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
253{
254 const auto input1 = static_cast<const uint8_t *__restrict>(input1_ptr);
255 const auto input2 = static_cast<const uint8_t *__restrict>(input2_ptr);
256 const auto output = static_cast<int16_t *__restrict>(output_ptr);
257
258 const uint8x16_t bv = vld1q_u8(input2);
259 const uint8x16_t av = vld1q_u8(input1);
260
261 uint16x8_t tmp_low = vmovl_u8(vget_low_u8(av));
262 uint16x8_t tmp_high = vmovl_u8(vget_high_u8(av));
263 tmp_low = vmulq_u16(tmp_low, vmovl_u8(vget_low_u8(bv)));
264 tmp_high = vmulq_u16(tmp_high, vmovl_u8(vget_high_u8(bv)));
265
266 if(is_scale255)
267 {
268 tmp_low = scale255_U16_U16(tmp_low);
269 tmp_high = scale255_U16_U16(tmp_high);
270 }
271 else
272 {
273 const int16x8_t vn = vdupq_n_s16(-n);
274
275 if(is_sat)
276 {
277 tmp_low = vqshlq_u16(tmp_low, vn);
278 tmp_high = vqshlq_u16(tmp_high, vn);
279 }
280 else
281 {
282 tmp_low = vshlq_u16(tmp_low, vn);
283 tmp_high = vshlq_u16(tmp_high, vn);
284 }
285 }
286
287 if(is_sat)
288 {
289 static const uint16x8_t max = vdupq_n_u16(SHRT_MAX);
290
291 tmp_low = vminq_u16(tmp_low, max);
292 tmp_high = vminq_u16(tmp_high, max);
293 }
294
295 vst1q_s16(output, vreinterpretq_s16_u16(tmp_low));
296 vst1q_s16(output + 8, vreinterpretq_s16_u16(tmp_high));
297}
298
299template <bool is_scale255, bool is_sat>
300void mul_S16_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
301{
302 const auto input1 = static_cast<const int16_t *__restrict>(input1_ptr);
303 const auto input2 = static_cast<const uint8_t *__restrict>(input2_ptr);
304 const auto output = static_cast<int16_t *__restrict>(output_ptr);
305
306 const int16x8x2_t ta1 = vld2q_s16(input1);
307 const uint8x8x2_t ta2u = vld2_u8(input2);
308 const int16x8x2_t ta2 =
309 {
310 {
311 vreinterpretq_s16_u16(vmovl_u8(ta2u.val[0])),
312 vreinterpretq_s16_u16(vmovl_u8(ta2u.val[1]))
313 }
314 };
315
316 const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);
317
318 vst2q_s16(output, result);
319}
320
321template <bool is_scale255, bool is_sat>
322void mul_U8_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
323{
324 // Simply swap the two input buffers
325 mul_S16_U8_S16_n<is_scale255, is_sat>(input2_ptr, input1_ptr, output_ptr, n);
326}
327} // namespace
328
329NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel()
330 : _func_float(nullptr), _func_int(nullptr), _func_q_int(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
331{
332}
333
334void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
335{
Georgios Pinitasf0dea702017-07-03 18:17:28 +0100336 ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
337
338 // Auto initialize output if not initialized
339 {
340 set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
341
342 if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
343 {
344 set_format_if_unknown(*output->info(), Format::S16);
345 }
346 else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
347 {
348 set_format_if_unknown(*output->info(), Format::F32);
349 }
350 else if(input1->info()->data_type() == DataType::QS8 && input2->info()->data_type() == DataType::QS8)
351 {
352 set_data_type_if_unknown(*output->info(), DataType::QS8);
353 set_fixed_point_position_if_zero(*output->info(), input1->info()->fixed_point_position());
354 }
355 }
356
357 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100358 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32);
359 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32);
360 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32);
361 ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
362 "Output can only be U8 if both inputs are U8");
363 if(output->info()->data_type() == DataType::QS8 || input1->info()->data_type() == DataType::QS8 || output->info()->data_type() == DataType::QS8)
364 {
365 // All data types must be QS8
366 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2, output);
367 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input1, input2, output);
368 }
369
370 _input1 = input1;
371 _input2 = input2;
372 _output = output;
373 _scale = scale;
374 _scale_exponent = 0;
375 _func_int = nullptr;
376 _func_q_int = nullptr;
377 _func_float = nullptr;
378
379 bool is_scale_255 = false;
380 // Check and validate scaling factor
381 if(std::abs(scale - scale255_constant) < 0.00001f)
382 {
383 ARM_COMPUTE_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
384 ARM_COMPUTE_UNUSED(rounding_policy);
385
386 is_scale_255 = true;
387 }
388 else
389 {
390 ARM_COMPUTE_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO);
391 ARM_COMPUTE_UNUSED(rounding_policy);
392
393 int exponent = 0;
394 const float normalized_mantissa = std::frexp(scale, &exponent);
395
396 // Use int scaling if factor is equal to 1/2^n for 0 <= n <= 15
397 // frexp returns 0.5 as mantissa which means that the exponent will be in the range of -1 <= e <= 14
398 // Moreover, it will be negative as we deal with 1/2^n
399 if((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1))
400 {
401 // Store the positive exponent. We know that we compute 1/2^n
402 // Additionally we need to subtract 1 to compensate that frexp used a mantissa of 0.5
403 _scale_exponent = std::abs(exponent - 1);
404 }
405 else
406 {
407 ARM_COMPUTE_ERROR("Scale value not supported (Should be 1/(2^n) or 1/255");
408 }
409 }
410
411 const DataType dt_input1 = input1->info()->data_type();
412 const DataType dt_input2 = input2->info()->data_type();
413 const DataType dt_output = output->info()->data_type();
414 const bool is_sat = (overflow_policy == ConvertPolicy::SATURATE);
415
416 if(DataType::U8 == dt_input1 && DataType::U8 == dt_input2 && DataType::U8 == dt_output)
417 {
418 if(is_scale_255)
419 {
420 _func_int = is_sat ? &mul_U8_U8_U8_n<true, true> : &mul_U8_U8_U8_n<true, false>;
421 }
422 else
423 {
424 _func_int = is_sat ? &mul_U8_U8_U8_n<false, true> : &mul_U8_U8_U8_n<false, false>;
425 }
426 }
427 else if(DataType::S16 == dt_input1 && DataType::S16 == dt_input2 && DataType::S16 == dt_output)
428 {
429 if(is_scale_255)
430 {
431 _func_int = is_sat ? &mul_S16_S16_S16_n<true, true> : &mul_S16_S16_S16_n<true, false>;
432 }
433 else
434 {
435 _func_int = is_sat ? &mul_S16_S16_S16_n<false, true> : &mul_S16_S16_S16_n<false, false>;
436 }
437 }
438 else if(DataType::S16 == dt_input1 && DataType::U8 == dt_input2 && DataType::S16 == dt_output)
439 {
440 if(is_scale_255)
441 {
442 _func_int = is_sat ? &mul_S16_U8_S16_n<true, true> : &mul_S16_U8_S16_n<true, false>;
443 }
444 else
445 {
446 _func_int = is_sat ? &mul_S16_U8_S16_n<false, true> : &mul_S16_U8_S16_n<false, false>;
447 }
448 }
449 else if(DataType::U8 == dt_input1 && DataType::S16 == dt_input2 && DataType::S16 == dt_output)
450 {
451 if(is_scale_255)
452 {
453 _func_int = is_sat ? &mul_U8_S16_S16_n<true, true> : &mul_U8_S16_S16_n<true, false>;
454 }
455 else
456 {
457 _func_int = is_sat ? &mul_U8_S16_S16_n<false, true> : &mul_U8_S16_S16_n<false, false>;
458 }
459 }
460 else if(DataType::U8 == dt_input1 && DataType::U8 == dt_input2 && DataType::S16 == dt_output)
461 {
462 if(is_scale_255)
463 {
464 _func_int = is_sat ? &mul_U8_U8_S16_n<true, true> : &mul_U8_U8_S16_n<true, false>;
465 }
466 else
467 {
468 _func_int = is_sat ? &mul_U8_U8_S16_n<false, true> : &mul_U8_U8_S16_n<false, false>;
469 }
470 }
471 else if(DataType::QS8 == dt_input1 && DataType::QS8 == dt_input2 && DataType::QS8 == dt_output)
472 {
473 if(is_scale_255)
474 {
475 _func_q_int = is_sat ? &mul_QS8_QS8_QS8_n<true, true> : &mul_QS8_QS8_QS8_n<true, false>;
476 }
477 else
478 {
479 _func_q_int = is_sat ? &mul_QS8_QS8_QS8_n<false, true> : &mul_QS8_QS8_QS8_n<false, false>;
480 }
481 }
482 else if(DataType::F32 == dt_input1 && DataType::F32 == dt_input2 && DataType::F32 == dt_output)
483 {
484 _func_float = &mul_F32_F32_F32_n<false, false>;
485 _func_int = nullptr;
486 }
487 else
488 {
489 ARM_COMPUTE_ERROR("You called with the wrong img formats");
490 }
491
492 constexpr unsigned int num_elems_processed_per_iteration = 16;
493
494 // Configure kernel window
495 Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration));
496 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
497
498 update_window_and_padding(win,
499 AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration),
500 AccessWindowHorizontal(input2->info(), 0, num_elems_processed_per_iteration),
501 output_access);
502
503 ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(),
504 input2->info()->valid_region());
505
506 output_access.set_valid_region(win, valid_region);
507
508 INEKernel::configure(win);
509}
510
511void NEPixelWiseMultiplicationKernel::run(const Window &window)
512{
513 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
514 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
515
516 Iterator input1(_input1, window);
517 Iterator input2(_input2, window);
518 Iterator output(_output, window);
519
520 if(_func_int != nullptr)
521 {
522 execute_window_loop(window, [&](const Coordinates & id)
523 {
524 (*_func_int)(input1.ptr(), input2.ptr(), output.ptr(), _scale_exponent);
525 },
526 input1, input2, output);
527 }
528 else if(_func_q_int != nullptr)
529 {
530 int fixed_point_position = _input1->info()->fixed_point_position();
531 execute_window_loop(window, [&](const Coordinates & id)
532 {
533 (*_func_q_int)(input1.ptr(), input2.ptr(), output.ptr(), _scale_exponent, fixed_point_position);
534 },
535 input1, input2, output);
536 }
537 else
538 {
539 ARM_COMPUTE_ERROR_ON(_func_float == nullptr);
540 execute_window_loop(window, [&](const Coordinates & id)
541 {
542 (*_func_float)(input1.ptr(), input2.ptr(), output.ptr(), _scale);
543 },
544 input1, input2, output);
545 }
546}