/*
 * Copyright (c) 2016-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"

#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/NEON/NEAsymm.h"
#include "arm_compute/core/NEON/NESymm.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"

#include <arm_neon.h>

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_fp16.h> // needed for float16_t
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

namespace arm_compute
{
namespace
{
const float       scale255_constant      = 1.f / 255.f;
const float32x4_t scale255_constant_f32q = vdupq_n_f32(scale255_constant);
const float32x4_t positive_round_f32q    = vdupq_n_f32(0.5f);

inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_UNUSED(overflow_policy);
    ARM_COMPUTE_UNUSED(rounding_policy);

    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
                                                         DataType::S16, DataType::QSYMM16,
                                                         DataType::S32, DataType::F16, DataType::F32);
    if(is_data_type_quantized(input1->data_type()) || is_data_type_quantized(input2->data_type()))
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(overflow_policy == ConvertPolicy::WRAP, "ConvertPolicy cannot be WRAP if datatype is quantized");
    }

    if(output->total_size() > 0)
    {
        const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
                                        "Output can only be U8 if both inputs are U8");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::QASYMM8 && (input1->data_type() != DataType::QASYMM8 || input2->data_type() != DataType::QASYMM8),
                                        "Output can only be QASYMM8 if both inputs are QASYMM8");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::QASYMM8_SIGNED && (input1->data_type() != DataType::QASYMM8_SIGNED || input2->data_type() != DataType::QASYMM8_SIGNED),
                                        "Output can only be QASYMM8_SIGNED if both inputs are QASYMM8_SIGNED");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::QSYMM16 && (input1->data_type() != DataType::QSYMM16 || input2->data_type() != DataType::QSYMM16),
                                        "Output can only be QSYMM16 if both inputs are QSYMM16");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && (input1->data_type() != DataType::QSYMM16 || input2->data_type() != DataType::QSYMM16),
                                        "Output can only be S32 if both inputs are QSYMM16");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 output");
    }

    if(std::abs(scale - scale255_constant) < 0.00001f)
    {
        ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO);

        int         exponent            = 0;
        const float normalized_mantissa = std::frexp(scale, &exponent);

        // Use int scaling if factor is equal to 1/2^n for 0 <= n <= 15
        // frexp returns 0.5 as the mantissa, so for scale = 1/2^n the exponent is e = 1 - n, i.e. -14 <= e <= 1
        // Moreover, e is non-positive for every n >= 1 as we deal with 1/2^n
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(!((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1)), "Scale value not supported (Should be 1/(2^n) or 1/255)");
    }

    return Status{};
}

/* Scales a given vector by 1/255.
 *
 * @note This does not work for all cases. e.g. for float of 0.49999999999999994 and large floats.
 *
 * @param in Input vector to scale.
 * @return Scaled output rounded to nearest (round half up).
 */
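// Worked example of the round-half-up trick below: 383 * (1/255) = 1.502f -> +0.5f = 2.002f,
// which vcvt truncates toward zero to 2, while 382 * (1/255) = 1.498f -> 1.998f -> 1.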
inline int32x4_t scale255_S32_S32(int32x4_t in)
{
    // Scale
    const float32x4_t tmp = vmulq_f32(vcvtq_f32_s32(in), scale255_constant_f32q);
    // Round to nearest (round half up)
    // Add +0.5 for all values
    // Afterwards vcvt rounds toward zero
    return vcvtq_s32_f32(vaddq_f32(tmp, positive_round_f32q));
}

inline uint16x8_t scale255_U16_U16(uint16x8_t in)
{
    const int32x4_t tmp_s1 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(in))));
    const int32x4_t tmp_s2 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(in))));
    return vreinterpretq_u16_s16(vcombine_s16(vmovn_s32(tmp_s2), vmovn_s32(tmp_s1)));
}

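// Requantization dispatch helpers: T appears only in the return type, so it cannot be deduced
// from the arguments; callers select the signed/unsigned path explicitly as vquantize<T>(...).
// The uint8_t overload forwards to the non-template arm_compute::vquantize from NEAsymm.h.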
template <typename T>
inline typename std::enable_if<std::is_same<T, int8_t>::value, int8x16_t>::type
vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
{
    return vquantize_signed(val, info);
}

template <typename T>
inline typename std::enable_if<std::is_same<T, uint8_t>::value, uint8x16_t>::type
vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
{
    return vquantize(val, info);
}

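// Saturating QASYMM8/QASYMM8_SIGNED multiplication: inputs are dequantized to float,
// multiplied, then requantized. The user scale is folded into the requantization step by
// dividing the output quantization scale by it (tmp_qua_info below).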
template <typename T>
void mul_saturate_quantized_8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16 / sizeof(T);
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();
    const UniformQuantizationInfo tmp_qua_info    = { output_qua_info.scale / scale, output_qua_info.offset };

    if(is_broadcast_across_x)
    {
        const bool                    is_broadcast_input_2 = input2_win.x().step() == 0;
        Window                        broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window                        non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor                *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor                *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
        const UniformQuantizationInfo broadcast_qinfo      = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo  = non_broadcast_tensor->info()->quantization_info().uniform();

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        using ExactTagType = typename wrapper::traits::neon_vector<T, window_step_x>::tag_type;

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const T *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<T *>(output.ptr());

            const auto broadcast_value     = *reinterpret_cast<const T *>(broadcast_input.ptr());
            const auto broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);

                // Dequantize inputs
                const float32x4x4_t in1_f32x4x4 = vdequantize(non_broadcast_v, non_broadcast_qinfo);
                const float32x4x4_t in2_f32x4x4 = vdequantize(broadcast_value_vec, broadcast_qinfo);

                const float32x4x4_t out_f32x4x4 =
                {
                    vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
                    vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
                    vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
                    vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
                };

                // Quantize output
                const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
                wrapper::vstore(output_ptr + x, result);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                // Dequantize inputs
                const T     in1     = *(non_broadcast_input_ptr + x);
                const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(in1, non_broadcast_qinfo);
                const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(broadcast_value, broadcast_qinfo);
                const float tmp_f   = tmp_in1 * tmp_in2;

                // Quantize output
                const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info);
                *(output_ptr + x)  = tmp_qua;
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qua_info = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qua_info = in2->info()->quantization_info().uniform();

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const T *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const T *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<T *>(output.ptr());

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto input1_q = wrapper::vloadq(input1_ptr + x);
                const auto input2_q = wrapper::vloadq(input2_ptr + x);

                // Dequantize inputs
                const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
                const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);

                const float32x4x4_t out_f32x4x4 =
                {
                    vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
                    vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
                    vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
                    vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
                };

                // Quantize output
                const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
                wrapper::vstore(output_ptr + x, result);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                // Dequantize inputs
                const T     in1     = *(input1_ptr + x);
                const T     in2     = *(input2_ptr + x);
                const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(in1, input1_qua_info);
                const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(in2, input2_qua_info);
                const float tmp_f   = tmp_in1 * tmp_in2;

                // Quantize output
                const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info);
                *(output_ptr + x)  = tmp_qua;
            }
        },
        input1, input2, output);
    }
}

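// QSYMM16 * QSYMM16 -> QSYMM16: same dequantize/multiply/requantize scheme as the 8-bit path
// above; QSYMM16 is a symmetric format, so only the scales (no offsets) take part.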
void mul_saturate_QSYMM16_QSYMM16_QSYMM16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    const UniformQuantizationInfo input1_qua_info = in1->info()->quantization_info().uniform();
    const UniformQuantizationInfo input2_qua_info = in2->info()->quantization_info().uniform();
    const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();

    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<qsymm16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const qsymm16x8x2_t input1_q =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const qsymm16x8x2_t input2_q =
            {
                {
                    vld1q_s16(input2_ptr + x),
                    vld1q_s16(input2_ptr + x + 8),
                }
            };

            // Dequantize inputs
            const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
            const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);

            const float32x4x4_t out_f32x4x4 =
            {
                vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
                vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
                vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
                vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
            };

            const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info);
            vst1q_s16(output_ptr + x, result.val[0]);
            vst1q_s16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            // Dequantize inputs
            float tmp_in1 = static_cast<float>(*(input1_ptr + x)) * input1_qua_info.scale;
            float tmp_in2 = static_cast<float>(*(input2_ptr + x)) * input2_qua_info.scale;
            float tmp_f   = tmp_in1 * tmp_in2;
            // Quantize output; lrintf() rounds to nearest, matching the rounding of the vectorized path
            int32_t   tmp     = lrintf(tmp_f / tmp_qua_info.scale);
            qsymm16_t tmp_qua = static_cast<qsymm16_t>((tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp));
            *(output_ptr + x) = tmp_qua;
        }
    },
    input1, input2, output);
}

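// QSYMM16 * QSYMM16 -> S32: validate_arguments() only accepts scale == 1.f for this
// combination, so the parameter is unused and the widened 32-bit products are stored as-is.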
void mul_QSYMM16_QSYMM16_S32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int scale)
{
    ARM_COMPUTE_UNUSED(scale);

    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int32_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const qsymm16x8x2_t input1_q =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const qsymm16x8x2_t input2_q =
            {
                {
                    vld1q_s16(input2_ptr + x),
                    vld1q_s16(input2_ptr + x + 8),
                }
            };

            const int32x4x4_t in1_s32 =
            {
                {
                    vmovl_s16(vget_low_s16(input1_q.val[0])),
                    vmovl_s16(vget_high_s16(input1_q.val[0])),
                    vmovl_s16(vget_low_s16(input1_q.val[1])),
                    vmovl_s16(vget_high_s16(input1_q.val[1])),
                }
            };
            const int32x4x4_t in2_s32 =
            {
                {
                    vmovl_s16(vget_low_s16(input2_q.val[0])),
                    vmovl_s16(vget_high_s16(input2_q.val[0])),
                    vmovl_s16(vget_low_s16(input2_q.val[1])),
                    vmovl_s16(vget_high_s16(input2_q.val[1])),
                }
            };

            const int32x4x4_t result =
            {
                {
                    vmulq_s32(in1_s32.val[0], in2_s32.val[0]),
                    vmulq_s32(in1_s32.val[1], in2_s32.val[1]),
                    vmulq_s32(in1_s32.val[2], in2_s32.val[2]),
                    vmulq_s32(in1_s32.val[3], in2_s32.val[3]),
                }
            };

            vst1q_s32(output_ptr + x, result.val[0]);
            vst1q_s32(output_ptr + x + 4, result.val[1]);
            vst1q_s32(output_ptr + x + 8, result.val[2]);
            vst1q_s32(output_ptr + x + 12, result.val[3]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp       = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
            *(output_ptr + x) = tmp;
        }
    },
    input1, input2, output);
}

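// U8 * U8 -> U8. is_scale255 selects the float-based 1/255 scaling path; otherwise the result
// is right-shifted by n (scale == 1/2^n). is_sat selects saturating shifts and narrowing.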
template <bool is_scale255, bool is_sat>
void mul_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16 / sizeof(uint8_t);
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const uint8x16_t ta1 = wrapper::vloadq(input1_ptr + x);
            const uint8x16_t ta2 = wrapper::vloadq(input2_ptr + x);

            uint16x8_t       tmp1_high = vmovl_u8(vget_high_u8(ta1));
            const uint16x8_t tmp2_high = vmovl_u8(vget_high_u8(ta2));
            uint16x8_t       tmp1_low  = vmovl_u8(vget_low_u8(ta1));
            const uint16x8_t tmp2_low  = vmovl_u8(vget_low_u8(ta2));

            tmp1_high = vmulq_u16(tmp1_high, tmp2_high);
            tmp1_low  = vmulq_u16(tmp1_low, tmp2_low);

            if(is_scale255)
            {
                tmp1_high = scale255_U16_U16(tmp1_high);
                tmp1_low  = scale255_U16_U16(tmp1_low);
            }
            else
            {
                const int16x8_t vn = vdupq_n_s16(-n);

                if(is_sat)
                {
                    tmp1_high = vqshlq_u16(tmp1_high, vn);
                    tmp1_low  = vqshlq_u16(tmp1_low, vn);
                }
                else
                {
                    tmp1_high = vshlq_u16(tmp1_high, vn);
                    tmp1_low  = vshlq_u16(tmp1_low, vn);
                }
            }
            if(is_sat)
            {
                vst1q_u8(output_ptr + x, vcombine_u8(vqmovn_u16(tmp1_low), vqmovn_u16(tmp1_high)));
            }
            else
            {
                vst1q_u8(output_ptr + x, vcombine_u8(vmovn_u16(tmp1_low), vmovn_u16(tmp1_high)));
            }
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            uint16_t tmp = static_cast<uint16_t>(*(input1_ptr + x)) * static_cast<uint16_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;
                tmp         = static_cast<uint16_t>(tmp_f + 0.5f);
            }
            else
            {
                tmp >>= n;
            }
            if(is_sat && tmp > 255)
            {
                tmp = 255;
            }
            *(output_ptr + x) = static_cast<uint8_t>(tmp);
        }
    },
    input1, input2, output);
}

template <bool is_scale255, bool is_sat>
inline int16x8_t mul_S16_S16_S16_n_loop(const int16x8_t &input1, const int16x8_t &input2, int n)
{
    int32x4_t       tmp1_high = vmovl_s16(vget_high_s16(input1));
    const int32x4_t tmp2_high = vmovl_s16(vget_high_s16(input2));
    int32x4_t       tmp1_low  = vmovl_s16(vget_low_s16(input1));
    const int32x4_t tmp2_low  = vmovl_s16(vget_low_s16(input2));

    tmp1_high = vmulq_s32(tmp1_high, tmp2_high);
    tmp1_low  = vmulq_s32(tmp1_low, tmp2_low);

    if(is_scale255)
    {
        tmp1_high = scale255_S32_S32(tmp1_high);
        tmp1_low  = scale255_S32_S32(tmp1_low);
    }
    else
    {
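        // Scale is 1/2^n here: a plain arithmetic right shift rounds toward minus infinity,
        // but the TO_ZERO policy requires truncation toward zero. convert_* below is
        // sign_bit * (2^n - 1), which is added to negative values before the shift.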
        // Right shift amount
        const int32x4_t vn = vdupq_n_s32(-n);
        // Left shift amount
        const int32x4_t vnl = vdupq_n_s32(n);
        // Calculate conversion bit
        const uint32x4_t tmp1_high_u  = vreinterpretq_u32_s32(tmp1_high);
        const uint32x4_t tmp1_low_u   = vreinterpretq_u32_s32(tmp1_low);
        const uint32x4_t sign_high    = vshrq_n_u32(tmp1_high_u, 31);
        const uint32x4_t sign_low     = vshrq_n_u32(tmp1_low_u, 31);
        const int32x4_t  sign_high_s  = vreinterpretq_s32_u32(sign_high);
        const int32x4_t  sign_low_s   = vreinterpretq_s32_u32(sign_low);
        const int32x4_t  convert_high = vsubq_s32(vshlq_s32(sign_high_s, vnl), sign_high_s);
        const int32x4_t  convert_low  = vsubq_s32(vshlq_s32(sign_low_s, vnl), sign_low_s);
        if(is_sat)
        {
            tmp1_high = vqshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
            tmp1_low  = vqshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
        }
        else
        {
            tmp1_high = vshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
            tmp1_low  = vshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
        }
    }

    if(is_sat)
    {
        return vcombine_s16(vqmovn_s32(tmp1_low), vqmovn_s32(tmp1_high));
    }
    else
    {
        return vcombine_s16(vmovn_s32(tmp1_low), vmovn_s32(tmp1_high));
    }
}

template <bool is_scale255, bool is_sat>
inline int16x8x2_t mul_S16_S16_S16_n_k(const int16x8x2_t &input1, const int16x8x2_t &input2, int n)
{
    const int16x8x2_t result =
    {
        {
            // First 8 elements
            mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[0], input2.val[0], n),
            // Second 8 elements
            mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[1], input2.val[1], n)
        }
    };

    return result;
}

template <bool is_scale255, bool is_sat>
void mul_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const int16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const int16x8x2_t ta1 =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const int16x8x2_t ta2 =
            {
                {
                    vld1q_s16(input2_ptr + x),
                    vld1q_s16(input2_ptr + x + 8),
                }
            };
            const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);

            vst1q_s16(output_ptr + x, result.val[0]);
            vst1q_s16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;

                tmp = static_cast<int32_t>(tmp_f + 0.5f);
            }
            else
            {
                if(tmp >= 0)
                {
                    tmp >>= n;
                }
                else
                {
                    uint32_t mask = (1u << n) - 1;
                    tmp           = (tmp + static_cast<int32_t>(mask)) >> n;
                }
            }
            if(is_sat)
            {
                tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
            }
            *(output_ptr + x) = static_cast<int16_t>(tmp);
        }
    },
    input1, input2, output);
}

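// F32 * F32 -> F32 with optional broadcasting along x: when one input has a single element in
// x, its value is duplicated into a vector once per row and only the other input is streamed.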
void mul_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    constexpr int window_step_x         = 16 / sizeof(float);
    const auto    window_start_x        = static_cast<int>(window.x().start());
    const auto    window_end_x          = static_cast<int>(window.x().end());
    const bool    is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    using ExactTagType = typename wrapper::traits::neon_vector<float, window_step_x>::tag_type;

    if(is_broadcast_across_x)
    {
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const float *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<float *>(output.ptr());

            const float broadcast_value     = *reinterpret_cast<const float *>(broadcast_input.ptr());
            const auto  broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});
            const auto  scale_vec           = wrapper::vdup_n(scale, ExactTagType{});

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);
                auto       res             = wrapper::vmul(wrapper::vmul(broadcast_value_vec, non_broadcast_v), scale_vec);
                wrapper::vstore(output_ptr + x, res);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const auto non_broadcast_v = *(non_broadcast_input_ptr + x);
                *(output_ptr + x)          = broadcast_value * non_broadcast_v * scale;
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const float *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<float *>(output.ptr());

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto ta1       = wrapper::vloadq(input1_ptr + x);
                const auto ta2       = wrapper::vloadq(input2_ptr + x);
                const auto scale_vec = wrapper::vdup_n(scale, ExactTagType{});
                const auto res       = wrapper::vmul(wrapper::vmul(ta1, ta2), scale_vec);
                wrapper::vstore(output_ptr + x, res);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const auto ta1    = *(input1_ptr + x);
                const auto ta2    = *(input2_ptr + x);
                *(output_ptr + x) = ta1 * ta2 * scale;
            }
        },
        input1, input2, output);
    }
}

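// Multiplies two complex (2-channel F32) elements at once. Each float32x4_t holds two
// interleaved (real, imag) pairs; (a + ib)(c + id) = (ac - bd) + i(ad + bc) is formed from one
// vmul plus one vmla on the element-swapped, sign-flipped second operand.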
void c_mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr)
{
    const auto input1 = static_cast<const float *__restrict>(input1_ptr);
    const auto input2 = static_cast<const float *__restrict>(input2_ptr);
    const auto output = static_cast<float *__restrict>(output_ptr);

    const float32x4_t a = wrapper::vloadq(input1);
    float32x4_t       b = wrapper::vloadq(input2);

    using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;

    const float32x4_t mask  = { -1.0f, 1.0f, -1.0f, 1.0f };
    const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
    const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
    const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{});
    const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{});

    const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10);
    const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11);

    float32x4_t res = wrapper::vmul(tmp0, b);

    b = wrapper::vrev64(b);
    b = wrapper::vmul(b, mask);

    res = wrapper::vmla(res, tmp1, b);
    wrapper::vstore(output, res);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void mul_F16_F16_F16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const float16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const float16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<float16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const float16x8x2_t ta1 =
            {
                {
                    vld1q_f16(input1_ptr + x),
                    vld1q_f16(input1_ptr + x + 8),
                }
            };
            const float16x8x2_t ta2 =
            {
                {
                    vld1q_f16(input2_ptr + x),
                    vld1q_f16(input2_ptr + x + 8),
                }
            };
            const float16x8_t   scale_vec = vdupq_n_f16(scale);
            const float16x8x2_t result    =
            {
                {
                    vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec),
                    vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec),
                }
            };
            vst1q_f16(output_ptr + x, result.val[0]);
            vst1q_f16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            const auto ta1    = *(input1_ptr + x);
            const auto ta2    = *(input2_ptr + x);
            *(output_ptr + x) = ta1 * ta2 * scale;
        }
    },
    input1, input2, output);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

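// U8 * U8 -> S16. The product of two U8 values is always non-negative, so saturation only
// needs to clamp against SHRT_MAX (vminq_u16 in the vector loop, the upper check in the tail).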
template <bool is_scale255, bool is_sat>
void mul_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16 / sizeof(uint8_t);
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const uint8x16_t bv = wrapper::vloadq(input2_ptr + x);
            const uint8x16_t av = wrapper::vloadq(input1_ptr + x);

            uint16x8_t tmp_low  = vmovl_u8(vget_low_u8(av));
            uint16x8_t tmp_high = vmovl_u8(vget_high_u8(av));
            tmp_low             = vmulq_u16(tmp_low, vmovl_u8(vget_low_u8(bv)));
            tmp_high            = vmulq_u16(tmp_high, vmovl_u8(vget_high_u8(bv)));

            if(is_scale255)
            {
                tmp_low  = scale255_U16_U16(tmp_low);
                tmp_high = scale255_U16_U16(tmp_high);
            }
            else
            {
                const int16x8_t vn = vdupq_n_s16(-n);

                if(is_sat)
                {
                    tmp_low  = vqshlq_u16(tmp_low, vn);
                    tmp_high = vqshlq_u16(tmp_high, vn);
                }
                else
                {
                    tmp_low  = vshlq_u16(tmp_low, vn);
                    tmp_high = vshlq_u16(tmp_high, vn);
                }
            }

            if(is_sat)
            {
                static const uint16x8_t max = vdupq_n_u16(SHRT_MAX);

                tmp_low  = vminq_u16(tmp_low, max);
                tmp_high = vminq_u16(tmp_high, max);
            }

            vst1q_s16(output_ptr + x, vreinterpretq_s16_u16(tmp_low));
            vst1q_s16(output_ptr + x + 8, vreinterpretq_s16_u16(tmp_high));
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;
                tmp         = static_cast<int32_t>(tmp_f + 0.5f);
            }
            else
            {
                tmp >>= n;
            }

            if(is_sat)
            {
                tmp = (tmp > SHRT_MAX) ? SHRT_MAX : tmp;
            }

            *(output_ptr + x) = static_cast<int16_t>(tmp);
        }
    },
    input1, input2, output);
}

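// S16 * U8 -> S16: the U8 operand is widened to S16 with vmovl_u8 and the work is delegated
// to the same mul_S16_S16_S16_n_k core used by the S16 * S16 path.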
template <bool is_scale255, bool is_sat>
void mul_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const int16x8x2_t ta1 =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const uint8x8x2_t ta2u =
            {
                {
                    vld1_u8(input2_ptr + x),
                    vld1_u8(input2_ptr + x + 8),
                }
            };
            const int16x8x2_t ta2 =
            {
                {
                    vreinterpretq_s16_u16(vmovl_u8(ta2u.val[0])),
                    vreinterpretq_s16_u16(vmovl_u8(ta2u.val[1]))
                }
            };

            const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);

            vst1q_s16(output_ptr + x, result.val[0]);
            vst1q_s16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;

                tmp = static_cast<int32_t>(tmp_f + 0.5f);
            }
            else
            {
                if(tmp >= 0)
                {
                    tmp >>= n;
                }
                else
                {
                    uint32_t mask = (1u << n) - 1;
                    tmp           = (tmp + static_cast<int32_t>(mask)) >> n;
                }
            }
            if(is_sat)
            {
                tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
            }
            *(output_ptr + x) = static_cast<int16_t>(tmp);
        }
    },
    input1, input2, output);
}

template <bool is_scale255, bool is_sat>
void mul_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Simply swap the two input buffers
    mul_S16_U8_S16<is_scale255, is_sat>(in2, in1, out, window, n);
}
} // namespace

NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel()
    : _func_float(nullptr), _func_int(nullptr), _func_quantized(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
{
}

void NEPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_UNUSED(rounding_policy);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));

    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
    const TensorShape &out_shape    = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    set_shape_if_empty(*output, out_shape);

    _scale          = scale;
    _scale_exponent = 0;
    _func_quantized = nullptr;
    _func_int       = nullptr;
    _func_float     = nullptr;

    bool is_scale_255 = false;
    // Check and validate scaling factor
    if(std::abs(scale - scale255_constant) < 0.00001f)
    {
        is_scale_255 = true;
    }
    else
    {
        int exponent = 0;

        std::frexp(scale, &exponent);

        // Store the positive exponent. We know that we compute 1/2^n
        // Additionally we need to subtract 1 to compensate for frexp returning a mantissa of 0.5
        _scale_exponent = std::abs(exponent - 1);
    }

    const DataType dt_input1 = input1->data_type();
    const DataType dt_input2 = input2->data_type();
    const DataType dt_output = output->data_type();
    const bool     is_sat    = (overflow_policy == ConvertPolicy::SATURATE);

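    // Select the kernel implementation from the (input1, input2, output) data-type
    // combination and the scaling mode.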
    switch(dt_input1)
    {
        case DataType::QASYMM8:
            if(dt_input2 == DataType::QASYMM8 && dt_output == DataType::QASYMM8)
            {
                _func_quantized = &mul_saturate_quantized_8<uint8_t>;
            }
            break;
        case DataType::QASYMM8_SIGNED:
            if(dt_input2 == DataType::QASYMM8_SIGNED)
            {
                _func_quantized = &mul_saturate_quantized_8<int8_t>;
            }
            break;
        case DataType::QSYMM16:
            if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::QSYMM16)
            {
                _func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16;
            }
            else if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::S32)
            {
                _func_int = &mul_QSYMM16_QSYMM16_S32;
            }
            break;
        case DataType::S16:
            if(DataType::U8 == dt_input2 && DataType::S16 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_S16_U8_S16<true, true> : &mul_S16_U8_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_S16_U8_S16<false, true> : &mul_S16_U8_S16<false, false>;
                }
            }
            if(DataType::S16 == dt_input2 && DataType::S16 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_S16_S16_S16<true, true> : &mul_S16_S16_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_S16_S16_S16<false, true> : &mul_S16_S16_S16<false, false>;
                }
            }
            break;
        case DataType::U8:
            if(DataType::U8 == dt_input2 && DataType::U8 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_U8_U8_U8<true, true> : &mul_U8_U8_U8<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_U8_U8_U8<false, true> : &mul_U8_U8_U8<false, false>;
                }
            }
            else if(DataType::U8 == dt_input2 && DataType::S16 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_U8_U8_S16<true, true> : &mul_U8_U8_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_U8_U8_S16<false, true> : &mul_U8_U8_S16<false, false>;
                }
            }
            else if(DataType::S16 == dt_input2 && DataType::S16 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_U8_S16_S16<true, true> : &mul_U8_S16_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_U8_S16_S16<false, true> : &mul_U8_S16_S16<false, false>;
                }
            }
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func_float = &mul_F16_F16_F16;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func_float = &mul_F32_F32_F32;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported combination of input and output data types");
    }

    // Configure kernel window
    Coordinates coord;
    coord.set_num_dimensions(output->num_dimensions());
    output->set_valid_region(valid_region);
    Window win = calculate_max_window(valid_region, Steps());

    INEKernel::configure(win);
}

Status NEPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy,
                                                 RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));

    return Status{};
}

void NEPixelWiseMultiplicationKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    auto input1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto input2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto output = tensors.get_tensor(TensorType::ACL_DST);

    if(_func_quantized != nullptr)
    {
        (*_func_quantized)(input1, input2, output, window, _scale);
    }
    else if(_func_int != nullptr)
    {
        (*_func_int)(input1, input2, output, window, _scale_exponent);
    }
    else
    {
        ARM_COMPUTE_ERROR_ON(_func_float == nullptr);
        (*_func_float)(input1, input2, output, window, _scale);
    }
}
namespace
{
constexpr unsigned int num_elems_processed_per_iteration_complex = 2;

Status validate_arguments_complex(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 2, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 2, DataType::F32);

    const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

    // Validate in case of configured output
    if(output->total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 2, DataType::F32);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window_complex(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
{
    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
    const TensorShape &out_shape    = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    const TensorInfo out_info(out_shape, input1->num_channels(), input1->data_type());
    auto_init_if_empty(*output, out_info);

    Window win        = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration_complex));
    Window win_input1 = win.broadcast_if_dimension_le_one(*input1);
    Window win_input2 = win.broadcast_if_dimension_le_one(*input2);

    AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration_complex);
    AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_complex);
    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_complex);

    bool window_changed = update_window_and_padding(win_input1, input1_access)
                          || update_window_and_padding(win_input2, input2_access)
                          || update_window_and_padding(win, output_access);

    output_access.set_valid_region(win, valid_region);

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
} // namespace

void NEComplexPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_complex(input1, input2, output));

    // Configure kernel window
    auto win_config = validate_and_configure_window_complex(input1, input2, output);
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    // Create kernel
    INEKernel::configure(win_config.second);
}

Status NEComplexPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_complex(input1, input2, output));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_complex(input1->clone().get(), input2->clone().get(), output->clone().get()).first);

    return Status{};
}

void NEComplexPixelWiseMultiplicationKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    auto input1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto input2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto output = tensors.get_tensor(TensorType::ACL_DST);

    Iterator input1_it(input1, window.broadcast_if_dimension_le_one(input1->info()->tensor_shape()));
    Iterator input2_it(input2, window.broadcast_if_dimension_le_one(input2->info()->tensor_shape()));
    Iterator output_it(output, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        c_mul_F32_F32_F32_n(input1_it.ptr(), input2_it.ptr(), output_it.ptr());
    },
    input1_it, input2_it, output_it);
}
} // namespace arm_compute