blob: 150db39695861e417804aad1151a4b99171fe803 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2016, 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"
25
26#include "arm_compute/core/Error.h"
27#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/IAccessWindow.h"
29#include "arm_compute/core/ITensor.h"
30#include "arm_compute/core/NEON/NEFixedPoint.h"
31#include "arm_compute/core/TensorInfo.h"
32#include "arm_compute/core/Validate.h"
33#include "arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h"
34
35#include <arm_neon.h>
36#include <climits>
37#include <cmath>
38#include <cstdint>
39#include <cstdlib>
40
Pablo Tellodf246182017-07-03 16:25:09 +010041#if ARM_COMPUTE_ENABLE_FP16
42#include <arm_fp16.h> // needed for float16_t
43#endif /* ARM_COMPUTE_ENABLE_FP16 */
44
Anthony Barbier6ff3b192017-09-04 18:44:23 +010045using namespace arm_compute;
46
// Forward-declare Coordinates so the execute_window_loop lambdas below can
// name it without pulling in its full definition in this translation unit.
namespace arm_compute
{
class Coordinates;
} // namespace arm_compute
51
52namespace
53{
// Reciprocal of 255 — the special scale used when multiplying two 8-bit images.
const float scale255_constant = 1.f / 255.f;
// Vector replica of the 1/255 constant for the NEON paths below.
const float32x4_t scale255_constant_f32q = vdupq_n_f32(scale255_constant);
// +0.5 bias added before a truncating float->int conversion to get round-half-up.
const float32x4_t positive_round_f32q = vdupq_n_f32(0.5f);
57
/* Scales a given vector by 1/255.
 *
 * @note This does not work for all cases. e.g. for float of 0.49999999999999994 and large floats.
 *
 * @param in Input vector to scale.
 * @return Scaled output rounded to nearest (round half up).
 */
inline int32x4_t scale255_S32_S32(int32x4_t in)
{
    // Multiply by 1/255 in floating point, then bias by +0.5 so that the
    // truncating (toward-zero) conversion back to integer rounds half up.
    const float32x4_t scaled  = vmulq_f32(vcvtq_f32_s32(in), scale255_constant_f32q);
    const float32x4_t rounded = vaddq_f32(scaled, positive_round_f32q);
    return vcvtq_s32_f32(rounded);
}
74
/* Scales each u16 lane by 1/255 with round-half-up.
 *
 * Widens each half of the vector to 32 bits, scales it through
 * scale255_S32_S32(), then narrows and recombines the halves.
 */
inline uint16x8_t scale255_U16_U16(uint16x8_t in)
{
    const int32x4_t scaled_low  = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(in))));
    const int32x4_t scaled_high = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(in))));
    return vreinterpretq_u16_s16(vcombine_s16(vmovn_s32(scaled_low), vmovn_s32(scaled_high)));
}
81
/* Multiplies 16 U8 elements from each input, scales, and stores 16 U8 results.
 *
 * @tparam is_scale255 True when scale == 1/255 (round-half-up path).
 * @tparam is_sat      True to saturate on overflow, false to wrap.
 * @param n            Exponent of the 1/2^n scale (ignored when is_scale255).
 */
template <bool is_scale255, bool is_sat>
void mul_U8_U8_U8_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
{
    const auto in1 = static_cast<const uint8_t *__restrict>(input1_ptr);
    const auto in2 = static_cast<const uint8_t *__restrict>(input2_ptr);
    const auto out = static_cast<uint8_t *__restrict>(output_ptr);

    const uint8x16_t a = vld1q_u8(in1);
    const uint8x16_t b = vld1q_u8(in2);

    // Widen to 16 bits so the 8x8-bit products cannot overflow, then multiply.
    uint16x8_t prod_low  = vmulq_u16(vmovl_u8(vget_low_u8(a)), vmovl_u8(vget_low_u8(b)));
    uint16x8_t prod_high = vmulq_u16(vmovl_u8(vget_high_u8(a)), vmovl_u8(vget_high_u8(b)));

    if(is_scale255)
    {
        prod_low  = scale255_U16_U16(prod_low);
        prod_high = scale255_U16_U16(prod_high);
    }
    else
    {
        // Divide by 2^n: a negative left-shift count performs a right shift.
        const int16x8_t shift = vdupq_n_s16(-n);

        if(is_sat)
        {
            prod_low  = vqshlq_u16(prod_low, shift);
            prod_high = vqshlq_u16(prod_high, shift);
        }
        else
        {
            prod_low  = vshlq_u16(prod_low, shift);
            prod_high = vshlq_u16(prod_high, shift);
        }
    }

    // Narrow back to 8 bits, saturating if requested.
    if(is_sat)
    {
        vst1q_u8(out, vcombine_u8(vqmovn_u16(prod_low), vqmovn_u16(prod_high)));
    }
    else
    {
        vst1q_u8(out, vcombine_u8(vmovn_u16(prod_low), vmovn_u16(prod_high)));
    }
}
130
/* Multiplies 16 QS8 fixed-point elements from each input (scale must be 1).
 *
 * @tparam is_sat True to use the saturating fixed-point multiply.
 * @param n       Exponent of the scale factor; only n == 0 is supported.
 */
template <bool is_scale255, bool is_sat>
void mul_QS8_QS8_QS8_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n, int fixed_point_position)
{
    // n is the exponent of the scaling factor, that is scale = 1/2^n. Currently, we only support scaling factor equal to 1 => n = 0.
    ARM_COMPUTE_ERROR_ON_MSG(n != 0, "Scaling factor different than 1 not supported for 8-bit fixed-point pixel-wise multiplication");
    ARM_COMPUTE_UNUSED(n);

    const qint8x16_t a = vld1q_qs8(static_cast<const qint8_t *__restrict>(input1_ptr));
    const qint8x16_t b = vld1q_qs8(static_cast<const qint8_t *__restrict>(input2_ptr));

    if(is_sat)
    {
        vst1q_s8(static_cast<qint8_t *__restrict>(output_ptr), vqmulq_qs8(a, b, fixed_point_position));
    }
    else
    {
        vst1q_s8(static_cast<qint8_t *__restrict>(output_ptr), vmulq_qs8(a, b, fixed_point_position));
    }
}
149
/* Multiplies 16 QS16 fixed-point elements from each input (scale must be 1).
 *
 * @tparam is_sat True to use the saturating fixed-point multiply.
 * @param n       Exponent of the scale factor; only n == 0 is supported.
 */
template <bool is_scale255, bool is_sat>
void mul_QS16_QS16_QS16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n, int fixed_point_position)
{
    // n is the exponent of the scaling factor, that is scale = 1/2^n. Currently, we only support scaling factor equal to 1 => n = 0.
    ARM_COMPUTE_ERROR_ON_MSG(n != 0, "Scaling factor different than 1 not supported for 16-bit fixed-point pixel-wise multiplication");
    ARM_COMPUTE_UNUSED(n);

    // De-interleaving loads: 16 elements split into two registers of 8.
    const qint16x8x2_t a = vld2q_qs16(static_cast<const qint16_t *__restrict>(input1_ptr));
    const qint16x8x2_t b = vld2q_qs16(static_cast<const qint16_t *__restrict>(input2_ptr));

    qint16x8x2_t res;
    if(is_sat)
    {
        res.val[0] = vqmulq_qs16(a.val[0], b.val[0], fixed_point_position);
        res.val[1] = vqmulq_qs16(a.val[1], b.val[1], fixed_point_position);
    }
    else
    {
        res.val[0] = vmulq_qs16(a.val[0], b.val[0], fixed_point_position);
        res.val[1] = vmulq_qs16(a.val[1], b.val[1], fixed_point_position);
    }

    vst2q_s16(static_cast<qint16_t *__restrict>(output_ptr), res);
}
189
/* Multiplies 8 S16 elements from each input and scales the 32-bit products.
 *
 * @tparam is_scale255 True when scale == 1/255 (round-half-up path).
 * @tparam is_sat      True to saturate on narrowing/shift, false to wrap.
 * @param n            Exponent of the 1/2^n scale (ignored when is_scale255).
 * @return 8 scaled S16 results.
 */
template <bool is_scale255, bool is_sat>
inline int16x8_t mul_S16_S16_S16_n_loop(const int16x8_t &input1, const int16x8_t &input2, int n)
{
    // Widen both operands to 32 bits so the 16x16-bit products cannot overflow
    int32x4_t tmp1_high = vmovl_s16(vget_high_s16(input1));
    const int32x4_t tmp2_high = vmovl_s16(vget_high_s16(input2));
    int32x4_t tmp1_low = vmovl_s16(vget_low_s16(input1));
    const int32x4_t tmp2_low = vmovl_s16(vget_low_s16(input2));

    tmp1_high = vmulq_s32(tmp1_high, tmp2_high);
    tmp1_low = vmulq_s32(tmp1_low, tmp2_low);

    if(is_scale255)
    {
        tmp1_high = scale255_S32_S32(tmp1_high);
        tmp1_low = scale255_S32_S32(tmp1_low);
    }
    else
    {
        // Divide by 2^n with rounding toward zero. A plain arithmetic right
        // shift would round negative values toward negative infinity, so a
        // correction of sign * (2^n - 1) is added before shifting.
        // Right shift amount
        const int32x4_t vn = vdupq_n_s32(-n);
        // Left shift amount
        const int32x4_t vnl = vdupq_n_s32(n);
        // Calculate conversion bit
        const uint32x4_t tmp1_high_u = vreinterpretq_u32_s32(tmp1_high);
        const uint32x4_t tmp1_low_u = vreinterpretq_u32_s32(tmp1_low);
        // sign_* is 1 for negative lanes, 0 otherwise (logical shift of bit 31)
        const uint32x4_t sign_high = vshrq_n_u32(tmp1_high_u, 31);
        const uint32x4_t sign_low = vshrq_n_u32(tmp1_low_u, 31);
        const int32x4_t sign_high_s = vreinterpretq_s32_u32(sign_high);
        const int32x4_t sign_low_s = vreinterpretq_s32_u32(sign_low);
        // convert_* = (sign << n) - sign = sign * (2^n - 1)
        const int32x4_t convert_high = vsubq_s32(vshlq_s32(sign_high_s, vnl), sign_high_s);
        const int32x4_t convert_low = vsubq_s32(vshlq_s32(sign_low_s, vnl), sign_low_s);
        if(is_sat)
        {
            tmp1_high = vqshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
            tmp1_low = vqshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
        }
        else
        {
            tmp1_high = vshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
            tmp1_low = vshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
        }
    }

    // Narrow back to 16 bits, saturating if requested
    if(is_sat)
    {
        return vcombine_s16(vqmovn_s32(tmp1_low), vqmovn_s32(tmp1_high));
    }
    else
    {
        return vcombine_s16(vmovn_s32(tmp1_low), vmovn_s32(tmp1_high));
    }
}
242
/* Applies the S16 multiply-and-scale loop to both halves of a de-interleaved
 * pair of 16-element vectors.
 */
template <bool is_scale255, bool is_sat>
inline int16x8x2_t mul_S16_S16_S16_n_k(const int16x8x2_t &input1, const int16x8x2_t &input2, int n)
{
    int16x8x2_t out;
    // First 8 elements
    out.val[0] = mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[0], input2.val[0], n);
    // Second 8 elements
    out.val[1] = mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[1], input2.val[1], n);
    return out;
}
258
/* Multiplies 16 S16 elements from each input and stores 16 S16 results. */
template <bool is_scale255, bool is_sat>
void mul_S16_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
{
    // De-interleaving load, multiply-and-scale both halves, matching store.
    const int16x8x2_t a = vld2q_s16(static_cast<const int16_t *__restrict>(input1_ptr));
    const int16x8x2_t b = vld2q_s16(static_cast<const int16_t *__restrict>(input2_ptr));

    vst2q_s16(static_cast<int16_t *__restrict>(output_ptr), mul_S16_S16_S16_n_k<is_scale255, is_sat>(a, b, n));
}
272
/* Multiplies 16 F32 elements from each input and applies the scale factor.
 *
 * The template parameters are unused here: float handles any scale directly.
 */
template <bool is_scale255, bool is_sat>
void mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale)
{
    const auto in1 = static_cast<const float *__restrict>(input1_ptr);
    const auto in2 = static_cast<const float *__restrict>(input2_ptr);
    const auto out = static_cast<float *__restrict>(output_ptr);

    const float32x4_t   scale_vec = vdupq_n_f32(scale);
    const float32x4x4_t a         = vld4q_f32(in1);
    const float32x4x4_t b         = vld4q_f32(in2);

    float32x4x4_t res;
    for(int i = 0; i < 4; ++i)
    {
        res.val[i] = vmulq_f32(vmulq_f32(a.val[i], b.val[i]), scale_vec);
    }
    vst4q_f32(out, res);
}
294
/* Multiplies 16 F16 elements from each input and applies the scale factor.
 *
 * Only functional when the library is built with FP16 support; otherwise
 * calling this function raises a fatal error at runtime.
 */
template <bool is_scale255, bool is_sat>
void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale)
{
    // Mark every parameter as used so the non-FP16 build compiles warning-free.
    // Fix: 'scale' was previously not marked, leaving an unused-parameter
    // warning when ARM_COMPUTE_ENABLE_FP16 is not defined.
    ARM_COMPUTE_UNUSED(input1_ptr);
    ARM_COMPUTE_UNUSED(input2_ptr);
    ARM_COMPUTE_UNUSED(output_ptr);
    ARM_COMPUTE_UNUSED(scale);
#ifdef ARM_COMPUTE_ENABLE_FP16
    const auto input1 = static_cast<const float16_t *__restrict>(input1_ptr);
    const auto input2 = static_cast<const float16_t *__restrict>(input2_ptr);
    const auto output = static_cast<float16_t *__restrict>(output_ptr);
    // De-interleaving loads: 16 half-precision elements per iteration
    const float16x8x2_t ta1       = vld2q_f16(input1);
    const float16x8x2_t ta2       = vld2q_f16(input2);
    const float16x8_t   scale_vec = vdupq_n_f16(scale);
    const float16x8x2_t result =
    {
        {
            vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec),
            vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec),
        }
    };
    vst2q_f16(output, result);
#else  /* ARM_COMPUTE_ENABLE_FP16 */
    ARM_COMPUTE_ERROR("Not supported. Recompile the library with arch=arm64-v8.2-a.");
#endif /* ARM_COMPUTE_ENABLE_FP16 */
}
320
/* Multiplies 16 U8 elements from each input and stores 16 S16 results.
 *
 * @tparam is_scale255 True when scale == 1/255 (round-half-up path).
 * @tparam is_sat      True to clamp results to SHRT_MAX, false to wrap.
 * @param n            Exponent of the 1/2^n scale (ignored when is_scale255).
 */
template <bool is_scale255, bool is_sat>
void mul_U8_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
{
    const auto input1 = static_cast<const uint8_t *__restrict>(input1_ptr);
    const auto input2 = static_cast<const uint8_t *__restrict>(input2_ptr);
    const auto output = static_cast<int16_t *__restrict>(output_ptr);

    const uint8x16_t av = vld1q_u8(input1);
    const uint8x16_t bv = vld1q_u8(input2);

    // Widen to unsigned 16 bits and multiply (255 * 255 fits in u16)
    uint16x8_t tmp_low  = vmulq_u16(vmovl_u8(vget_low_u8(av)), vmovl_u8(vget_low_u8(bv)));
    uint16x8_t tmp_high = vmulq_u16(vmovl_u8(vget_high_u8(av)), vmovl_u8(vget_high_u8(bv)));

    if(is_scale255)
    {
        tmp_low  = scale255_U16_U16(tmp_low);
        tmp_high = scale255_U16_U16(tmp_high);
    }
    else
    {
        // Divide by 2^n: a negative left-shift count performs a right shift
        const int16x8_t vn = vdupq_n_s16(-n);

        if(is_sat)
        {
            tmp_low  = vqshlq_u16(tmp_low, vn);
            tmp_high = vqshlq_u16(tmp_high, vn);
        }
        else
        {
            tmp_low  = vshlq_u16(tmp_low, vn);
            tmp_high = vshlq_u16(tmp_high, vn);
        }
    }

    if(is_sat)
    {
        // Clamp to SHRT_MAX so reinterpreting as signed never yields a negative.
        // Fix: 'max' was a function-local static, which adds a thread-safe
        // initialization guard on this hot per-row path; a plain const vdup is
        // a single instruction and needs no guard.
        const uint16x8_t max = vdupq_n_u16(SHRT_MAX);

        tmp_low  = vminq_u16(tmp_low, max);
        tmp_high = vminq_u16(tmp_high, max);
    }

    vst1q_s16(output, vreinterpretq_s16_u16(tmp_low));
    vst1q_s16(output + 8, vreinterpretq_s16_u16(tmp_high));
}
368
/* Multiplies 16 S16 elements by 16 U8 elements and stores 16 S16 results. */
template <bool is_scale255, bool is_sat>
void mul_S16_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
{
    const auto in_s16 = static_cast<const int16_t *__restrict>(input1_ptr);
    const auto in_u8  = static_cast<const uint8_t *__restrict>(input2_ptr);
    const auto out    = static_cast<int16_t *__restrict>(output_ptr);

    // De-interleaving loads, matching the layout used by the S16 x S16 helper
    const int16x8x2_t a    = vld2q_s16(in_s16);
    const uint8x8x2_t b_u8 = vld2_u8(in_u8);

    // Widen the U8 operand to S16 (values 0..255 are exactly representable)
    int16x8x2_t b;
    b.val[0] = vreinterpretq_s16_u16(vmovl_u8(b_u8.val[0]));
    b.val[1] = vreinterpretq_s16_u16(vmovl_u8(b_u8.val[1]));

    vst2q_s16(out, mul_S16_S16_S16_n_k<is_scale255, is_sat>(a, b, n));
}
390
/* Multiplies 16 U8 elements by 16 S16 elements and stores 16 S16 results. */
template <bool is_scale255, bool is_sat>
void mul_U8_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
{
    // Simply swap the two input buffers: multiplication is commutative, so the
    // U8 x S16 case reuses the S16 x U8 implementation with operands exchanged.
    mul_S16_U8_S16_n<is_scale255, is_sat>(input2_ptr, input1_ptr, output_ptr, n);
}
397} // namespace
398
// Default-construct in an unconfigured state: all function and tensor pointers
// are null and the scale is zero. configure() must be called before run().
NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel()
    : _func_float(nullptr), _func_int(nullptr), _func_q_int(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
{
}
403
/* Configures the kernel: auto-initializes the output tensor info, validates
 * the inputs, decomposes the scale factor, selects the per-row processing
 * function for the data-type combination, and sets up the execution window.
 *
 * Exactly one of _func_int / _func_q_int / _func_float is non-null afterwards;
 * run() dispatches on that. Supported scales are 1/255 and 1/2^n; any other
 * value is a fatal error.
 */
void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);

    // Auto initialize output if not initialized
    {
        set_shape_if_empty(*output->info(), input1->info()->tensor_shape());

        // Output format defaults follow the "wider" input type; QS8 requires
        // both inputs to be QS8 and inherits input1's fixed-point position.
        if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
        {
            set_format_if_unknown(*output->info(), Format::S16);
        }
        else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
        {
            set_format_if_unknown(*output->info(), Format::F32);
        }
        else if(input1->info()->data_type() == DataType::F16 || input2->info()->data_type() == DataType::F16)
        {
            set_format_if_unknown(*output->info(), Format::F16);
        }
        else if(input1->info()->data_type() == DataType::QS8 && input2->info()->data_type() == DataType::QS8)
        {
            set_data_type_if_unknown(*output->info(), DataType::QS8);
            set_fixed_point_position_if_zero(*output->info(), input1->info()->fixed_point_position());
        }
    }

    // Validate shapes and data types
    ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
    ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
                             "Output can only be U8 if both inputs are U8");
    if(is_data_type_fixed_point(input1->info()->data_type()) || is_data_type_fixed_point(input2->info()->data_type()) || is_data_type_fixed_point(output->info()->data_type()))
    {
        // Check that all data types are the same and all fixed-point positions are the same
        ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output);
    }

    _input1 = input1;
    _input2 = input2;
    _output = output;
    _scale = scale;
    _scale_exponent = 0;
    _func_int = nullptr;
    _func_q_int = nullptr;
    _func_float = nullptr;

    bool is_scale_255 = false;
    // Check and validate scaling factor
    if(std::abs(scale - scale255_constant) < 0.00001f)
    {
        // 1/255 path requires a round-to-nearest policy
        ARM_COMPUTE_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
        ARM_COMPUTE_UNUSED(rounding_policy);

        is_scale_255 = true;
    }
    else
    {
        // 1/2^n path requires truncation toward zero
        ARM_COMPUTE_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO);
        ARM_COMPUTE_UNUSED(rounding_policy);

        int exponent = 0;
        const float normalized_mantissa = std::frexp(scale, &exponent);

        // Use int scaling if factor is equal to 1/2^n for 0 <= n <= 15
        // frexp returns 0.5 as mantissa which means that the exponent will be in the range of -1 <= e <= 14
        // Moreover, it will be negative as we deal with 1/2^n
        if((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1))
        {
            // Store the positive exponent. We know that we compute 1/2^n
            // Additionally we need to subtract 1 to compensate that frexp used a mantissa of 0.5
            _scale_exponent = std::abs(exponent - 1);
        }
        else
        {
            ARM_COMPUTE_ERROR("Scale value not supported (Should be 1/(2^n) or 1/255");
        }
    }

    const DataType dt_input1 = input1->info()->data_type();
    const DataType dt_input2 = input2->info()->data_type();
    const DataType dt_output = output->info()->data_type();
    const bool is_sat = (overflow_policy == ConvertPolicy::SATURATE);

    // Select the per-row processing function matching the (in1, in2, out)
    // data-type combination and the scale/saturation flags.
    if(DataType::U8 == dt_input1 && DataType::U8 == dt_input2 && DataType::U8 == dt_output)
    {
        if(is_scale_255)
        {
            _func_int = is_sat ? &mul_U8_U8_U8_n<true, true> : &mul_U8_U8_U8_n<true, false>;
        }
        else
        {
            _func_int = is_sat ? &mul_U8_U8_U8_n<false, true> : &mul_U8_U8_U8_n<false, false>;
        }
    }
    else if(DataType::S16 == dt_input1 && DataType::S16 == dt_input2 && DataType::S16 == dt_output)
    {
        if(is_scale_255)
        {
            _func_int = is_sat ? &mul_S16_S16_S16_n<true, true> : &mul_S16_S16_S16_n<true, false>;
        }
        else
        {
            _func_int = is_sat ? &mul_S16_S16_S16_n<false, true> : &mul_S16_S16_S16_n<false, false>;
        }
    }
    else if(DataType::S16 == dt_input1 && DataType::U8 == dt_input2 && DataType::S16 == dt_output)
    {
        if(is_scale_255)
        {
            _func_int = is_sat ? &mul_S16_U8_S16_n<true, true> : &mul_S16_U8_S16_n<true, false>;
        }
        else
        {
            _func_int = is_sat ? &mul_S16_U8_S16_n<false, true> : &mul_S16_U8_S16_n<false, false>;
        }
    }
    else if(DataType::U8 == dt_input1 && DataType::S16 == dt_input2 && DataType::S16 == dt_output)
    {
        if(is_scale_255)
        {
            _func_int = is_sat ? &mul_U8_S16_S16_n<true, true> : &mul_U8_S16_S16_n<true, false>;
        }
        else
        {
            _func_int = is_sat ? &mul_U8_S16_S16_n<false, true> : &mul_U8_S16_S16_n<false, false>;
        }
    }
    else if(DataType::U8 == dt_input1 && DataType::U8 == dt_input2 && DataType::S16 == dt_output)
    {
        if(is_scale_255)
        {
            _func_int = is_sat ? &mul_U8_U8_S16_n<true, true> : &mul_U8_U8_S16_n<true, false>;
        }
        else
        {
            _func_int = is_sat ? &mul_U8_U8_S16_n<false, true> : &mul_U8_U8_S16_n<false, false>;
        }
    }
    else if(DataType::QS8 == dt_input1 && DataType::QS8 == dt_input2 && DataType::QS8 == dt_output)
    {
        if(is_scale_255)
        {
            _func_q_int = is_sat ? &mul_QS8_QS8_QS8_n<true, true> : &mul_QS8_QS8_QS8_n<true, false>;
        }
        else
        {
            _func_q_int = is_sat ? &mul_QS8_QS8_QS8_n<false, true> : &mul_QS8_QS8_QS8_n<false, false>;
        }
    }
    else if(DataType::QS16 == dt_input1 && DataType::QS16 == dt_input2 && DataType::QS16 == dt_output)
    {
        if(is_scale_255)
        {
            _func_q_int = is_sat ? &mul_QS16_QS16_QS16_n<true, true> : &mul_QS16_QS16_QS16_n<true, false>;
        }
        else
        {
            _func_q_int = is_sat ? &mul_QS16_QS16_QS16_n<false, true> : &mul_QS16_QS16_QS16_n<false, false>;
        }
    }
    else if(DataType::F16 == dt_input1 && DataType::F16 == dt_input2 && DataType::F16 == dt_output)
    {
        // Float paths handle any scale directly; template flags are unused
        _func_float = &mul_F16_F16_F16_n<false, false>;
        _func_int = nullptr;
    }
    else if(DataType::F32 == dt_input1 && DataType::F32 == dt_input2 && DataType::F32 == dt_output)
    {
        _func_float = &mul_F32_F32_F32_n<false, false>;
        _func_int = nullptr;
    }
    else
    {
        ARM_COMPUTE_ERROR("You called with the wrong img formats");
    }

    // Every processing function consumes 16 elements per call
    constexpr unsigned int num_elems_processed_per_iteration = 16;

    // Configure kernel window
    Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration));
    AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);

    update_window_and_padding(win,
                              AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration),
                              AccessWindowHorizontal(input2->info(), 0, num_elems_processed_per_iteration),
                              output_access);

    ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(),
                                                       input2->info()->valid_region());

    output_access.set_valid_region(win, valid_region);

    INEKernel::configure(win);
}
599
600void NEPixelWiseMultiplicationKernel::run(const Window &window)
601{
602 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
603 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
604
605 Iterator input1(_input1, window);
606 Iterator input2(_input2, window);
607 Iterator output(_output, window);
608
609 if(_func_int != nullptr)
610 {
611 execute_window_loop(window, [&](const Coordinates & id)
612 {
613 (*_func_int)(input1.ptr(), input2.ptr(), output.ptr(), _scale_exponent);
614 },
615 input1, input2, output);
616 }
617 else if(_func_q_int != nullptr)
618 {
619 int fixed_point_position = _input1->info()->fixed_point_position();
620 execute_window_loop(window, [&](const Coordinates & id)
621 {
622 (*_func_q_int)(input1.ptr(), input2.ptr(), output.ptr(), _scale_exponent, fixed_point_position);
623 },
624 input1, input2, output);
625 }
626 else
627 {
628 ARM_COMPUTE_ERROR_ON(_func_float == nullptr);
629 execute_window_loop(window, [&](const Coordinates & id)
630 {
631 (*_func_float)(input1.ptr(), input2.ptr(), output.ptr(), _scale);
632 },
633 input1, input2, output);
634 }
635}