/*
 * Copyright (c) 2016-2020 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"

#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/NEON/NEAsymm.h"
#include "arm_compute/core/NEON/NESymm.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"

#include <arm_neon.h>
#include <climits> // SCHAR_MIN/MAX, UCHAR_MAX, SHRT_MIN/MAX
#include <cmath>   // std::abs, std::frexp, lrintf

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_fp16.h> // needed for float16_t
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

namespace arm_compute
{
namespace
{
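// Constants for the scale == 1/255 fast path; rounding is performed by adding
// 0.5 and truncating (round half up), see scale255_S32_S32() below.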
const float       scale255_constant      = 1.f / 255.f;
const float32x4_t scale255_constant_f32q = vdupq_n_f32(scale255_constant);
const float32x4_t positive_round_f32q    = vdupq_n_f32(0.5f);

inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_UNUSED(overflow_policy);
    ARM_COMPUTE_UNUSED(rounding_policy);

    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
                                                         DataType::S16, DataType::QSYMM16,
                                                         DataType::S32, DataType::F16, DataType::F32);
    if(is_data_type_quantized(input1->data_type()) || is_data_type_quantized(input2->data_type()))
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(overflow_policy == ConvertPolicy::WRAP, "ConvertPolicy cannot be WRAP if datatype is quantized");
    }

    if(output->total_size() > 0)
    {
        if(is_data_type_quantized(output->data_type()))
        {
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2, output);
        }

        const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
                                        "Output can only be U8 if both inputs are U8");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && (input1->data_type() != DataType::QSYMM16 || input2->data_type() != DataType::QSYMM16),
                                        "Output can only be S32 if both inputs are QSYMM16");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 output");
    }

    if(std::abs(scale - scale255_constant) < 0.00001f)
    {
        ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO);

        int         exponent            = 0;
        const float normalized_mantissa = std::frexp(scale, &exponent);

        // Use int scaling if factor is equal to 1/2^n for 0 <= n <= 15
        // frexp returns a mantissa of 0.5, so for scale = 1/2^n the exponent is e = 1 - n, i.e. -14 <= e <= 1
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(!((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1)), "Scale value not supported (should be 1/(2^n) or 1/255)");
    }

    return Status{};
}

/* Scales a given vector by 1/255.
 *
 * @note This does not work for all cases, e.g. for a float of 0.49999999999999994 or for large floats.
 *
 * @param in Input vector to scale.
 * @return   Scaled output rounded to nearest (round half up).
 */
inline int32x4_t scale255_S32_S32(int32x4_t in)
{
    // Scale
    const float32x4_t tmp = vmulq_f32(vcvtq_f32_s32(in), scale255_constant_f32q);
    // Round to nearest (round half up)
    // Add +0.5 for all values
    // Afterwards vcvt rounds toward zero
    return vcvtq_s32_f32(vaddq_f32(tmp, positive_round_f32q));
}

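// Scales each u16 lane by 1/255: widen to s32, scale and round, then narrow
// back (the product of two u8 values always fits in 16 bits).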
inline uint16x8_t scale255_U16_U16(uint16x8_t in)
{
    const int32x4_t tmp_s1 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(in))));
    const int32x4_t tmp_s2 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(in))));
    return vreinterpretq_u16_s16(vcombine_s16(vmovn_s32(tmp_s2), vmovn_s32(tmp_s1)));
}

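// Requantization helpers used by mul_saturate_quantized_8<T> for QASYMM8 and
// QASYMM8_SIGNED. Note for the uint8_t overload below: the unqualified call
// cannot deduce T, so overload resolution selects the non-template vquantize()
// helper (from NEAsymm.h) rather than recursing.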
template <typename T>
inline typename std::enable_if<std::is_same<T, int8_t>::value, int8x16_t>::type
vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
{
    return vquantize_signed(val, info);
}

template <typename T>
inline typename std::enable_if<std::is_same<T, uint8_t>::value, uint8x16_t>::type
vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
{
    return vquantize(val, info);
}

template <typename T>
inline typename std::enable_if<std::is_same<T, int8_t>::value, int8_t>::type
quantize(float val, const UniformQuantizationInfo &info)
{
    int32_t tmp = static_cast<int32_t>(val / info.scale) + info.offset;

    // Clamp to the representable range before narrowing
    T tmp_qua = static_cast<T>((tmp > SCHAR_MAX) ? SCHAR_MAX : ((tmp < SCHAR_MIN) ? SCHAR_MIN : tmp));
    return tmp_qua;
}

template <typename T>
inline typename std::enable_if<std::is_same<T, uint8_t>::value, uint8_t>::type
quantize(float val, const UniformQuantizationInfo &info)
{
    int32_t tmp = static_cast<int32_t>(val / info.scale) + info.offset;

    T tmp_qua = static_cast<T>((tmp > UCHAR_MAX) ? UCHAR_MAX : tmp);
    return tmp_qua;
}

template <typename T>
inline float dequantize(const T *input, const UniformQuantizationInfo &info)
{
    return static_cast<float>((*input) - info.offset) * info.scale;
}

template <typename T>
void mul_saturate_quantized_8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    const UniformQuantizationInfo input1_qua_info = in1->info()->quantization_info().uniform();
    const UniformQuantizationInfo input2_qua_info = in2->info()->quantization_info().uniform();
    const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();

    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16 / sizeof(T);
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

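    // Fold the user-supplied scale into the requantization step: quantizing
    // in1 * in2 with (output_scale / scale) is equivalent to quantizing
    // (in1 * in2) * scale with output_scale.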
    const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const T *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const T *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<T *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const auto input1_q = wrapper::vloadq(input1_ptr + x);
            const auto input2_q = wrapper::vloadq(input2_ptr + x);

            // Dequantize inputs
            const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
            const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);

            const float32x4x4_t out_f32x4x4 =
            {
                vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
                vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
                vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
                vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
            };

            // Quantize output
            const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
            wrapper::vstore(output_ptr + x, result);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            // Dequantize inputs
            float tmp_in1 = dequantize(input1_ptr + x, input1_qua_info);
            float tmp_in2 = dequantize(input2_ptr + x, input2_qua_info);
            float tmp_f   = tmp_in1 * tmp_in2;

            // Quantize output
            const auto tmp_qua = quantize<T>(tmp_f, tmp_qua_info);
            *(output_ptr + x)  = tmp_qua;
        }
    },
    input1, input2, output);
}

void mul_saturate_QSYMM16_QSYMM16_QSYMM16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    const UniformQuantizationInfo input1_qua_info = in1->info()->quantization_info().uniform();
    const UniformQuantizationInfo input2_qua_info = in2->info()->quantization_info().uniform();
    const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();

    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

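    // As in mul_saturate_quantized_8: fold the multiplication scale into the
    // output quantization scale.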
    const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<qsymm16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const qsymm16x8x2_t input1_q =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const qsymm16x8x2_t input2_q =
            {
                {
                    vld1q_s16(input2_ptr + x),
                    vld1q_s16(input2_ptr + x + 8),
                }
            };

            // Dequantize inputs
            const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
            const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);

            const float32x4x4_t out_f32x4x4 =
            {
                vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
                vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
                vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
                vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
            };

            const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info);
            vst1q_s16(output_ptr + x, result.val[0]);
            vst1q_s16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            // Dequantize inputs
            float tmp_in1 = static_cast<float>(*(input1_ptr + x)) * input1_qua_info.scale;
            float tmp_in2 = static_cast<float>(*(input2_ptr + x)) * input2_qua_info.scale;
            float tmp_f   = tmp_in1 * tmp_in2;

            // Quantize output; lrintf() rounds to nearest, matching the rounding of the vectorized vquantize_qsymm16 path
            int32_t   tmp     = lrintf(tmp_f / tmp_qua_info.scale);
            qsymm16_t tmp_qua = static_cast<qsymm16_t>((tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp));
            *(output_ptr + x) = tmp_qua;
        }
    },
    input1, input2, output);
}

void mul_QSYMM16_QSYMM16_S32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int scale)
{
    ARM_COMPUTE_UNUSED(scale);

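    // The S32 output stores the raw integer product of the QSYMM16 inputs;
    // validate_arguments() enforces scale == 1.f for this path, so the scale
    // parameter is unused.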
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int32_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const qsymm16x8x2_t input1_q =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const qsymm16x8x2_t input2_q =
            {
                {
                    vld1q_s16(input2_ptr + x),
                    vld1q_s16(input2_ptr + x + 8),
                }
            };

            const int32x4x4_t in1_s32 =
            {
                {
                    vmovl_s16(vget_low_s16(input1_q.val[0])),
                    vmovl_s16(vget_high_s16(input1_q.val[0])),
                    vmovl_s16(vget_low_s16(input1_q.val[1])),
                    vmovl_s16(vget_high_s16(input1_q.val[1])),
                }
            };
            const int32x4x4_t in2_s32 =
            {
                {
                    vmovl_s16(vget_low_s16(input2_q.val[0])),
                    vmovl_s16(vget_high_s16(input2_q.val[0])),
                    vmovl_s16(vget_low_s16(input2_q.val[1])),
                    vmovl_s16(vget_high_s16(input2_q.val[1])),
                }
            };

            const int32x4x4_t result =
            {
                {
                    vmulq_s32(in1_s32.val[0], in2_s32.val[0]),
                    vmulq_s32(in1_s32.val[1], in2_s32.val[1]),
                    vmulq_s32(in1_s32.val[2], in2_s32.val[2]),
                    vmulq_s32(in1_s32.val[3], in2_s32.val[3]),
                }
            };

            vst1q_s32(output_ptr + x, result.val[0]);
            vst1q_s32(output_ptr + x + 4, result.val[1]);
            vst1q_s32(output_ptr + x + 8, result.val[2]);
            vst1q_s32(output_ptr + x + 12, result.val[3]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp       = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
            *(output_ptr + x) = tmp;
        }
    },
    input1, input2, output);
}

template <bool is_scale255, bool is_sat>
void mul_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16 / sizeof(uint8_t);
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const uint8x16_t ta1 = wrapper::vloadq(input1_ptr + x);
            const uint8x16_t ta2 = wrapper::vloadq(input2_ptr + x);

            uint16x8_t       tmp1_high = vmovl_u8(vget_high_u8(ta1));
            const uint16x8_t tmp2_high = vmovl_u8(vget_high_u8(ta2));
            uint16x8_t       tmp1_low  = vmovl_u8(vget_low_u8(ta1));
            const uint16x8_t tmp2_low  = vmovl_u8(vget_low_u8(ta2));

            tmp1_high = vmulq_u16(tmp1_high, tmp2_high);
            tmp1_low  = vmulq_u16(tmp1_low, tmp2_low);

            if(is_scale255)
            {
                tmp1_high = scale255_U16_U16(tmp1_high);
                tmp1_low  = scale255_U16_U16(tmp1_low);
            }
            else
            {
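                // scale == 1/2^n is applied as a right shift by n
                // (vshl with a negative shift count shifts right)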
                const int16x8_t vn = vdupq_n_s16(-n);

                if(is_sat)
                {
                    tmp1_high = vqshlq_u16(tmp1_high, vn);
                    tmp1_low  = vqshlq_u16(tmp1_low, vn);
                }
                else
                {
                    tmp1_high = vshlq_u16(tmp1_high, vn);
                    tmp1_low  = vshlq_u16(tmp1_low, vn);
                }
            }
            if(is_sat)
            {
                vst1q_u8(output_ptr + x, vcombine_u8(vqmovn_u16(tmp1_low), vqmovn_u16(tmp1_high)));
            }
            else
            {
                vst1q_u8(output_ptr + x, vcombine_u8(vmovn_u16(tmp1_low), vmovn_u16(tmp1_high)));
            }
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            uint16_t tmp = static_cast<uint16_t>(*(input1_ptr + x)) * static_cast<uint16_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;
                tmp         = static_cast<uint16_t>(tmp_f + 0.5f);
            }
            else
            {
                tmp >>= n;
            }
            if(is_sat && tmp > 255)
            {
                tmp = 255;
            }
            *(output_ptr + x) = static_cast<uint8_t>(tmp);
        }
    },
    input1, input2, output);
}

template <bool is_scale255, bool is_sat>
inline int16x8_t mul_S16_S16_S16_n_loop(const int16x8_t &input1, const int16x8_t &input2, int n)
{
    int32x4_t       tmp1_high = vmovl_s16(vget_high_s16(input1));
    const int32x4_t tmp2_high = vmovl_s16(vget_high_s16(input2));
    int32x4_t       tmp1_low  = vmovl_s16(vget_low_s16(input1));
    const int32x4_t tmp2_low  = vmovl_s16(vget_low_s16(input2));

    tmp1_high = vmulq_s32(tmp1_high, tmp2_high);
    tmp1_low  = vmulq_s32(tmp1_low, tmp2_low);

    if(is_scale255)
    {
        tmp1_high = scale255_S32_S32(tmp1_high);
        tmp1_low  = scale255_S32_S32(tmp1_low);
    }
    else
    {
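        // An arithmetic right shift of a negative product rounds toward -inf;
        // adding (2^n - 1) to negative lanes first makes the shift truncate
        // toward zero, as required by RoundingPolicy::TO_ZERO. The conversion
        // bit below is (sign << n) - sign, i.e. (2^n - 1) for negative lanes
        // and 0 otherwise.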
        // Right shift amount
        const int32x4_t vn = vdupq_n_s32(-n);
        // Left shift amount
        const int32x4_t vnl = vdupq_n_s32(n);
        // Calculate conversion bit
        const uint32x4_t tmp1_high_u  = vreinterpretq_u32_s32(tmp1_high);
        const uint32x4_t tmp1_low_u   = vreinterpretq_u32_s32(tmp1_low);
        const uint32x4_t sign_high    = vshrq_n_u32(tmp1_high_u, 31);
        const uint32x4_t sign_low     = vshrq_n_u32(tmp1_low_u, 31);
        const int32x4_t  sign_high_s  = vreinterpretq_s32_u32(sign_high);
        const int32x4_t  sign_low_s   = vreinterpretq_s32_u32(sign_low);
        const int32x4_t  convert_high = vsubq_s32(vshlq_s32(sign_high_s, vnl), sign_high_s);
        const int32x4_t  convert_low  = vsubq_s32(vshlq_s32(sign_low_s, vnl), sign_low_s);
        if(is_sat)
        {
            tmp1_high = vqshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
            tmp1_low  = vqshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
        }
        else
        {
            tmp1_high = vshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
            tmp1_low  = vshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
        }
    }

    if(is_sat)
    {
        return vcombine_s16(vqmovn_s32(tmp1_low), vqmovn_s32(tmp1_high));
    }
    else
    {
        return vcombine_s16(vmovn_s32(tmp1_low), vmovn_s32(tmp1_high));
    }
}

template <bool is_scale255, bool is_sat>
inline int16x8x2_t mul_S16_S16_S16_n_k(const int16x8x2_t &input1, const int16x8x2_t &input2, int n)
{
    const int16x8x2_t result =
    {
        {
            // First 8 elements
            mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[0], input2.val[0], n),
            // Second 8 elements
            mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[1], input2.val[1], n)
        }
    };

    return result;
}

template <bool is_scale255, bool is_sat>
void mul_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const int16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const int16x8x2_t ta1 =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const int16x8x2_t ta2 =
            {
                {
                    vld1q_s16(input2_ptr + x),
                    vld1q_s16(input2_ptr + x + 8),
                }
            };
            const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);

            vst1q_s16(output_ptr + x, result.val[0]);
            vst1q_s16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;

                tmp = static_cast<int32_t>(tmp_f + 0.5f);
            }
            else
            {
                if(tmp >= 0)
                {
                    tmp >>= n;
                }
                else
                {
                    uint32_t mask = (1u << n) - 1;
                    tmp           = (tmp + static_cast<int32_t>(mask)) >> n;
                }
            }
            if(is_sat)
            {
                tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
            }
            *(output_ptr + x) = static_cast<int16_t>(tmp);
        }
    },
    input1, input2, output);
}

void mul_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    constexpr int window_step_x         = 16 / sizeof(float);
    const auto    window_start_x        = static_cast<int>(window.x().start());
    const auto    window_end_x          = static_cast<int>(window.x().end());
    const bool    is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    using ExactTagType = typename wrapper::traits::neon_vector<float, window_step_x>::tag_type;

    if(is_broadcast_across_x)
    {
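        // One input has a single element along X: load that value once per
        // row and multiply it against full vectors of the other input.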
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const float *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<float *>(output.ptr());

            const float broadcast_value     = *reinterpret_cast<const float *>(broadcast_input.ptr());
            const auto  broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});
            const auto  scale_vec           = wrapper::vdup_n(scale, ExactTagType{});

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);
                auto       res             = wrapper::vmul(wrapper::vmul(broadcast_value_vec, non_broadcast_v), scale_vec);
                wrapper::vstore(output_ptr + x, res);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const auto non_broadcast_v = *(non_broadcast_input_ptr + x);
                *(output_ptr + x)          = broadcast_value * non_broadcast_v * scale;
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const float *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<float *>(output.ptr());

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto ta1       = wrapper::vloadq(input1_ptr + x);
                const auto ta2       = wrapper::vloadq(input2_ptr + x);
                const auto scale_vec = wrapper::vdup_n(scale, ExactTagType{});
                const auto res       = wrapper::vmul(wrapper::vmul(ta1, ta2), scale_vec);
                wrapper::vstore(output_ptr + x, res);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const auto ta1    = *(input1_ptr + x);
                const auto ta2    = *(input2_ptr + x);
                *(output_ptr + x) = ta1 * ta2 * scale;
            }
        },
        input1, input2, output);
    }
}

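// Multiplies two complex float vectors stored as interleaved (re, im) pairs,
// two complex values per float32x4_t:
//   res = (a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
// The vrev64/mask trick builds (-b.im, b.re) so a single multiply-accumulate
// completes the complex product.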
void c_mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr)
{
    const auto input1 = static_cast<const float *__restrict>(input1_ptr);
    const auto input2 = static_cast<const float *__restrict>(input2_ptr);
    const auto output = static_cast<float *__restrict>(output_ptr);

    const float32x4_t a = wrapper::vloadq(input1);
    float32x4_t       b = wrapper::vloadq(input2);

    using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;

    const float32x4_t mask  = { -1.0f, 1.0f, -1.0f, 1.0f };
    const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
    const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
    const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{});
    const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{});

    const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10);
    const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11);

    float32x4_t res = wrapper::vmul(tmp0, b);

    b = wrapper::vrev64(b);
    b = wrapper::vmul(b, mask);

    res = wrapper::vmla(res, tmp1, b);
    wrapper::vstore(output, res);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void mul_F16_F16_F16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const float16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const float16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<float16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const float16x8x2_t ta1 =
            {
                {
                    vld1q_f16(input1_ptr + x),
                    vld1q_f16(input1_ptr + x + 8),
                }
            };
            const float16x8x2_t ta2 =
            {
                {
                    vld1q_f16(input2_ptr + x),
                    vld1q_f16(input2_ptr + x + 8),
                }
            };
            const float16x8_t   scale_vec = vdupq_n_f16(scale);
            const float16x8x2_t result    =
            {
                {
                    vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec),
                    vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec),
                }
            };
            vst1q_f16(output_ptr + x, result.val[0]);
            vst1q_f16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            const auto ta1    = *(input1_ptr + x);
            const auto ta2    = *(input2_ptr + x);
            *(output_ptr + x) = ta1 * ta2 * scale;
        }
    },
    input1, input2, output);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

template <bool is_scale255, bool is_sat>
void mul_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16 / sizeof(uint8_t);
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const uint8x16_t bv = wrapper::vloadq(input2_ptr + x);
            const uint8x16_t av = wrapper::vloadq(input1_ptr + x);

            uint16x8_t tmp_low  = vmovl_u8(vget_low_u8(av));
            uint16x8_t tmp_high = vmovl_u8(vget_high_u8(av));
            tmp_low             = vmulq_u16(tmp_low, vmovl_u8(vget_low_u8(bv)));
            tmp_high            = vmulq_u16(tmp_high, vmovl_u8(vget_high_u8(bv)));

            if(is_scale255)
            {
                tmp_low  = scale255_U16_U16(tmp_low);
                tmp_high = scale255_U16_U16(tmp_high);
            }
            else
            {
                const int16x8_t vn = vdupq_n_s16(-n);

                if(is_sat)
                {
                    tmp_low  = vqshlq_u16(tmp_low, vn);
                    tmp_high = vqshlq_u16(tmp_high, vn);
                }
                else
                {
                    tmp_low  = vshlq_u16(tmp_low, vn);
                    tmp_high = vshlq_u16(tmp_high, vn);
                }
            }

            if(is_sat)
            {
                static const uint16x8_t max = vdupq_n_u16(SHRT_MAX);

                tmp_low  = vminq_u16(tmp_low, max);
                tmp_high = vminq_u16(tmp_high, max);
            }

            vst1q_s16(output_ptr + x, vreinterpretq_s16_u16(tmp_low));
            vst1q_s16(output_ptr + x + 8, vreinterpretq_s16_u16(tmp_high));
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;
                tmp         = static_cast<int32_t>(tmp_f + 0.5f);
            }
            else
            {
                tmp >>= n;
            }

            if(is_sat)
            {
                tmp = (tmp > SHRT_MAX) ? SHRT_MAX : tmp;
            }

            *(output_ptr + x) = static_cast<int16_t>(tmp);
        }
    },
    input1, input2, output);
}

template <bool is_scale255, bool is_sat>
void mul_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const int16x8x2_t ta1 =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const uint8x8x2_t ta2u =
            {
                {
                    vld1_u8(input2_ptr + x),
                    vld1_u8(input2_ptr + x + 8),
                }
            };
            const int16x8x2_t ta2 =
            {
                {
                    vreinterpretq_s16_u16(vmovl_u8(ta2u.val[0])),
                    vreinterpretq_s16_u16(vmovl_u8(ta2u.val[1]))
                }
            };

            const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);

            vst1q_s16(output_ptr + x, result.val[0]);
            vst1q_s16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;

                tmp = static_cast<int32_t>(tmp_f + 0.5f);
            }
            else
            {
                if(tmp >= 0)
                {
                    tmp >>= n;
                }
                else
                {
                    uint32_t mask = (1u << n) - 1;
                    tmp           = (tmp + static_cast<int32_t>(mask)) >> n;
                }
            }
            if(is_sat)
            {
                tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
            }
            *(output_ptr + x) = static_cast<int16_t>(tmp);
        }
    },
    input1, input2, output);
}

template <bool is_scale255, bool is_sat>
void mul_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Simply swap the two input buffers
    mul_S16_U8_S16<is_scale255, is_sat>(in2, in1, out, window, n);
}
} // namespace

NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel()
    : _func_float(nullptr), _func_int(nullptr), _func_quantized(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
{
}

void NEPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_UNUSED(rounding_policy);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));

    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
    const TensorShape &out_shape    = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    set_shape_if_empty(*output, out_shape);

    _scale          = scale;
    _scale_exponent = 0;
    _func_quantized = nullptr;
    _func_int       = nullptr;
    _func_float     = nullptr;

    bool is_scale_255 = false;
    // Check and validate scaling factor
    if(std::abs(scale - scale255_constant) < 0.00001f)
    {
        is_scale_255 = true;
    }
    else
    {
        int exponent = 0;

        std::frexp(scale, &exponent);

        // Store the positive exponent. We know that we compute 1/2^n
        // Additionally we need to subtract 1 to compensate that frexp used a mantissa of 0.5
        _scale_exponent = std::abs(exponent - 1);
    }

    const DataType dt_input1 = input1->data_type();
    const DataType dt_input2 = input2->data_type();
    const DataType dt_output = output->data_type();
    const bool     is_sat    = (overflow_policy == ConvertPolicy::SATURATE);

    switch(dt_input1)
    {
        case DataType::QASYMM8:
            if(dt_input2 == DataType::QASYMM8 && dt_output == DataType::QASYMM8)
            {
                _func_quantized = &mul_saturate_quantized_8<uint8_t>;
            }
            break;
        case DataType::QASYMM8_SIGNED:
            if(dt_input2 == DataType::QASYMM8_SIGNED && dt_output == DataType::QASYMM8_SIGNED)
            {
                _func_quantized = &mul_saturate_quantized_8<int8_t>;
            }
            break;
        case DataType::QSYMM16:
            if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::QSYMM16)
            {
                _func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16;
            }
            else if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::S32)
            {
                _func_int = &mul_QSYMM16_QSYMM16_S32;
            }
            break;
        case DataType::S16:
            if(DataType::U8 == dt_input2 && DataType::S16 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_S16_U8_S16<true, true> : &mul_S16_U8_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_S16_U8_S16<false, true> : &mul_S16_U8_S16<false, false>;
                }
            }
            if(DataType::S16 == dt_input2 && DataType::S16 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_S16_S16_S16<true, true> : &mul_S16_S16_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_S16_S16_S16<false, true> : &mul_S16_S16_S16<false, false>;
                }
            }
            break;
        case DataType::U8:
            if(DataType::U8 == dt_input2 && DataType::U8 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_U8_U8_U8<true, true> : &mul_U8_U8_U8<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_U8_U8_U8<false, true> : &mul_U8_U8_U8<false, false>;
                }
            }
            else if(DataType::U8 == dt_input2 && DataType::S16 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_U8_U8_S16<true, true> : &mul_U8_U8_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_U8_U8_S16<false, true> : &mul_U8_U8_S16<false, false>;
                }
            }
            else if(DataType::S16 == dt_input2 && DataType::S16 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_U8_S16_S16<true, true> : &mul_U8_S16_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_U8_S16_S16<false, true> : &mul_U8_S16_S16<false, false>;
                }
            }
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func_float = &mul_F16_F16_F16;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func_float = &mul_F32_F32_F32;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported combination of input and output data types");
    }

    // Configure kernel window
    output->set_valid_region(valid_region);
    Window win = calculate_max_window(valid_region, Steps());

    INEKernel::configure(win);
}

Status NEPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy,
                                                 RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));

    return Status{};
}

void NEPixelWiseMultiplicationKernel::run_op(const InputTensorMap &inputs, const OutputTensorMap &outputs, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    auto input1 = inputs.at(TensorType::ACL_SRC_0);
    auto input2 = inputs.at(TensorType::ACL_SRC_1);
    auto output = outputs.at(TensorType::ACL_DST);

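    // Dispatch through the function pointer selected in configure(); for a
    // validated data type combination exactly one of the three is set.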
    if(_func_quantized != nullptr)
    {
        (*_func_quantized)(input1, input2, output, window, _scale);
    }
    else if(_func_int != nullptr)
    {
        (*_func_int)(input1, input2, output, window, _scale_exponent);
    }
    else
    {
        ARM_COMPUTE_ERROR_ON(_func_float == nullptr);
        (*_func_float)(input1, input2, output, window, _scale);
    }
}
namespace
{
constexpr unsigned int num_elems_processed_per_iteration_complex = 2;
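// Each iteration handles two complex elements, i.e. two (re, im) float pairs,
// matching the float32x4_t that c_mul_F32_F32_F32_n processes per call.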

Status validate_arguments_complex(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 2, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 2, DataType::F32);

    const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

    // Validate in case of configured output
    if(output->total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 2, DataType::F32);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window_complex(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
{
    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
    const TensorShape &out_shape    = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    const TensorInfo out_info(out_shape, input1->num_channels(), input1->data_type());
    auto_init_if_empty(*output, out_info);

    Window win        = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration_complex));
    Window win_input1 = win.broadcast_if_dimension_le_one(*input1);
    Window win_input2 = win.broadcast_if_dimension_le_one(*input2);

    AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration_complex);
    AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_complex);
    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_complex);

    bool window_changed = update_window_and_padding(win_input1, input1_access)
                          || update_window_and_padding(win_input2, input2_access)
                          || update_window_and_padding(win, output_access);

    output_access.set_valid_region(win, valid_region);

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
} // namespace

void NEComplexPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_complex(input1, input2, output));

    // Configure kernel window
    auto win_config = validate_and_configure_window_complex(input1, input2, output);
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    // Create kernel
    INEKernel::configure(win_config.second);
}

Status NEComplexPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_complex(input1, input2, output));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_complex(input1->clone().get(), input2->clone().get(), output->clone().get()).first);

    return Status{};
}

void NEComplexPixelWiseMultiplicationKernel::run_op(const InputTensorMap &inputs, const OutputTensorMap &outputs, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    auto input1 = inputs.at(TensorType::ACL_SRC_0);
    auto input2 = inputs.at(TensorType::ACL_SRC_1);
    auto output = outputs.at(TensorType::ACL_DST);

    Iterator input1_it(input1, window.broadcast_if_dimension_le_one(input1->info()->tensor_shape()));
    Iterator input2_it(input2, window.broadcast_if_dimension_le_one(input2->info()->tensor_shape()));
    Iterator output_it(output, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        c_mul_F32_F32_F32_n(input1_it.ptr(), input2_it.ptr(), output_it.ptr());
    },
    input1_it, input2_it, output_it);
}
} // namespace arm_compute