blob: 84683ea69fa2b0258f14d1625796caa739e59c77 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2016-2020 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"
25
Anthony Barbiereaefd002018-07-20 17:49:35 +010026#include "arm_compute/core/CPP/Validate.h"
Manuel Bottini79fa9a22019-02-22 17:54:22 +000027#include "arm_compute/core/NEON/NEAsymm.h"
Manuel Bottini7bb56c62019-06-26 15:17:09 +010028#include "arm_compute/core/NEON/NESymm.h"
giuros01154bc1c2019-03-26 17:44:40 +000029#include "arm_compute/core/NEON/wrapper/wrapper.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/TensorInfo.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010031
32#include <arm_neon.h>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010033
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +000034#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tellodf246182017-07-03 16:25:09 +010035#include <arm_fp16.h> // needed for float16_t
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +000036#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tellodf246182017-07-03 16:25:09 +010037
Anthony Barbier6ff3b192017-09-04 18:44:23 +010038namespace arm_compute
39{
Anthony Barbier6ff3b192017-09-04 18:44:23 +010040namespace
41{
// Scale factor applied when the caller requests scale == 1/255 (normalized U8/S16 products).
const float scale255_constant = 1.f / 255.f;
// Vector replica of the 1/255 factor, consumed by scale255_S32_S32().
const float32x4_t scale255_constant_f32q = vdupq_n_f32(scale255_constant);
// +0.5 bias added before the truncating float->int conversion to get round-half-up behaviour.
const float32x4_t positive_round_f32q = vdupq_n_f32(0.5f);
45
Georgios Pinitas631c41a2017-12-06 11:53:03 +000046inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
Ioan-Cristian Szabo754e9522017-11-28 18:29:43 +000047{
48 ARM_COMPUTE_UNUSED(overflow_policy);
49 ARM_COMPUTE_UNUSED(rounding_policy);
50
Anthony Barbiereaefd002018-07-20 17:49:35 +010051 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
SiCong Libb88f892020-08-28 11:18:47 +010052 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::S32, DataType::QSYMM16, DataType::F16,
53 DataType::F32);
54 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::S32, DataType::QSYMM16, DataType::F16,
55 DataType::F32);
Michele Di Giorgio9428a182020-03-30 14:10:20 +010056 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
57 DataType::S16, DataType::QSYMM16,
58 DataType::S32, DataType::F16, DataType::F32);
Georgios Pinitasd7d7e902019-12-18 15:40:54 +000059 if(is_data_type_quantized(input1->data_type()) || is_data_type_quantized(input2->data_type()))
Pablo Tello52ea9c22019-12-10 11:28:53 +000060 {
61 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
Georgios Pinitasd7d7e902019-12-18 15:40:54 +000062 ARM_COMPUTE_RETURN_ERROR_ON_MSG(overflow_policy == ConvertPolicy::WRAP, "ConvertPolicy cannot be WRAP if datatype is quantized");
Pablo Tello52ea9c22019-12-10 11:28:53 +000063 }
Manuel Bottini79fa9a22019-02-22 17:54:22 +000064
65 if(output->total_size() > 0)
66 {
Manuel Bottini79fa9a22019-02-22 17:54:22 +000067 const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
68 ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
69 ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
SiCong Libb88f892020-08-28 11:18:47 +010070 // clang-format off
71 ARM_COMPUTE_RETURN_ERROR_ON_MSG(
72 !(input1->data_type() == input2->data_type() && input2->data_type() == output->data_type()) &&
73 !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) &&
74 !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16) &&
75 !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) &&
76 !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) &&
77 !(input1->data_type() == DataType::QSYMM16 && input2->data_type() == DataType::QSYMM16 && output->data_type() == DataType::S32)
78 , "Invalid data type combination");
79 // clang-format on
80 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::S16 && output->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 output");
Manuel Bottini79fa9a22019-02-22 17:54:22 +000081 }
Michalis Spyrou861f0db2018-02-26 16:47:58 +000082
Ioan-Cristian Szabo754e9522017-11-28 18:29:43 +000083 if(std::abs(scale - scale255_constant) < 0.00001f)
84 {
85 ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
SiCong Libb88f892020-08-28 11:18:47 +010086 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::S32 && input2->data_type() == DataType::S32 && output->data_type() == DataType::S32,
87 "Scale == 1/255 is not supported if input and output are of data type S32");
Ioan-Cristian Szabo754e9522017-11-28 18:29:43 +000088 }
89 else
90 {
91 ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO);
92
93 int exponent = 0;
94 const float normalized_mantissa = std::frexp(scale, &exponent);
95
96 // Use int scaling if factor is equal to 1/2^n for 0 <= n <= 15
97 // frexp returns 0.5 as mantissa which means that the exponent will be in the range of -1 <= e <= 14
98 // Moreover, it will be negative as we deal with 1/2^n
99 ARM_COMPUTE_RETURN_ERROR_ON_MSG(!((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1)), "Scale value not supported (Should be 1/(2^n) or 1/255");
100 }
101
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000102 return Status{};
Ioan-Cristian Szabo754e9522017-11-28 18:29:43 +0000103}
104
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100105/* Scales a given vector by 1/255.
106 *
107 * @note This does not work for all cases. e.g. for float of 0.49999999999999994 and large floats.
108 *
109 * @param in Input vector to scale.
110 * @return Scaled output rounded to nearest (round half up).
111 */
112inline int32x4_t scale255_S32_S32(int32x4_t in)
113{
114 // Scale
115 const float32x4_t tmp = vmulq_f32(vcvtq_f32_s32(in), scale255_constant_f32q);
116 // Round to nearest (round half up)
117 // Add +0.5 for all values
118 // Afterwards vcvt rounds toward zero
119 return vcvtq_s32_f32(vaddq_f32(tmp, positive_round_f32q));
120}
121
/* Scales each u16 lane of a vector by 1/255, rounding half up.
 *
 * Widens each half to 32 bits, scales via scale255_S32_S32() and narrows back.
 */
inline uint16x8_t scale255_U16_U16(uint16x8_t in)
{
    // Widen both halves to s32 so the float-based scaling helper can be reused.
    const int32x4_t lo_scaled = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(in))));
    const int32x4_t hi_scaled = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(in))));
    // Narrow back, keeping the low half in the low lanes.
    const int16x8_t packed = vcombine_s16(vmovn_s32(lo_scaled), vmovn_s32(hi_scaled));
    return vreinterpretq_u16_s16(packed);
}
128
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100129template <typename T>
130inline typename std::enable_if<std::is_same<T, int8_t>::value, int8x16_t>::type
131vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
Manuel Bottini79fa9a22019-02-22 17:54:22 +0000132{
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100133 return vquantize_signed(val, info);
Manuel Bottini79fa9a22019-02-22 17:54:22 +0000134}
135
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100136template <typename T>
137inline typename std::enable_if<std::is_same<T, uint8_t>::value, uint8x16_t>::type
138vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
Pablo Tello52ea9c22019-12-10 11:28:53 +0000139{
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100140 return vquantize(val, info);
Pablo Tello52ea9c22019-12-10 11:28:53 +0000141}
142
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100143template <typename T>
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100144void mul_saturate_quantized_8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
145{
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100146 // Create input windows
147 Window win = window;
148 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
149 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
150
151 // Clear X Dimension on execution window as we handle manually
152 win.set(Window::DimX, Window::Dimension(0, 1, 1));
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100153
Sheri Zhanga449a362020-07-16 15:52:25 +0100154 const int window_step_x = 16 / sizeof(T);
155 const auto window_start_x = static_cast<int>(window.x().start());
156 const auto window_end_x = static_cast<int>(window.x().end());
157 const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100158
Sheri Zhanga449a362020-07-16 15:52:25 +0100159 const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();
160 const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };
Manuel Bottini7bb56c62019-06-26 15:17:09 +0100161
Sheri Zhanga449a362020-07-16 15:52:25 +0100162 if(is_broadcast_across_x)
Manuel Bottini7bb56c62019-06-26 15:17:09 +0100163 {
Sheri Zhanga449a362020-07-16 15:52:25 +0100164 const bool is_broadcast_input_2 = input2_win.x().step() == 0;
165 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
166 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
167 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
168 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
169 const UniformQuantizationInfo broadcast_qinfo = broadcast_tensor->info()->quantization_info().uniform();
170 const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();
Manuel Bottini7bb56c62019-06-26 15:17:09 +0100171
Sheri Zhanga449a362020-07-16 15:52:25 +0100172 // Clear X Dimension on execution window as we handle manually
173 non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
174
175 Iterator broadcast_input(broadcast_tensor, broadcast_win);
176 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
177 Iterator output(out, win);
178
179 using ExactTagType = typename wrapper::traits::neon_vector<T, window_step_x>::tag_type;
180
181 execute_window_loop(win, [&](const Coordinates &)
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100182 {
Sheri Zhanga449a362020-07-16 15:52:25 +0100183 const auto non_broadcast_input_ptr = reinterpret_cast<const T *>(non_broadcast_input.ptr());
184 const auto output_ptr = reinterpret_cast<T *>(output.ptr());
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100185
Sheri Zhanga449a362020-07-16 15:52:25 +0100186 const auto broadcast_value = *reinterpret_cast<const T *>(broadcast_input.ptr());
187 const auto broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100188
Sheri Zhanga449a362020-07-16 15:52:25 +0100189 // Compute window_step_x elements per iteration
190 int x = window_start_x;
191 for(; x <= (window_end_x - window_step_x); x += window_step_x)
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100192 {
Sheri Zhanga449a362020-07-16 15:52:25 +0100193 const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100194
Sheri Zhanga449a362020-07-16 15:52:25 +0100195 // Dequantize inputs
196 const float32x4x4_t in1_f32x4x4 = vdequantize(non_broadcast_v, non_broadcast_qinfo);
197 const float32x4x4_t in2_f32x4x4 = vdequantize(broadcast_value_vec, broadcast_qinfo);
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100198
Sheri Zhanga449a362020-07-16 15:52:25 +0100199 const float32x4x4_t out_f32x4x4 =
200 {
201 vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
202 vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
203 vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
204 vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
205 };
206
207 // Quantize output
208 const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
209 wrapper::vstore(output_ptr + x, result);
210 }
211
212 // Compute left-over elements
213 for(; x < window_end_x; ++x)
214 {
215 // Dequantize inputs
Michele Di Giorgio40aad9b2020-07-22 15:17:43 +0100216 const T in1 = *(non_broadcast_input_ptr + x);
217 const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(in1, non_broadcast_qinfo);
218 const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(broadcast_value, broadcast_qinfo);
219 const float tmp_f = tmp_in1 * tmp_in2;
Sheri Zhanga449a362020-07-16 15:52:25 +0100220
221 // Quantize output
Michele Di Giorgio40aad9b2020-07-22 15:17:43 +0100222 const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info);
Sheri Zhanga449a362020-07-16 15:52:25 +0100223 *(output_ptr + x) = tmp_qua;
224 }
225 },
226 broadcast_input, non_broadcast_input, output);
227 }
228 else
229 {
230 const UniformQuantizationInfo input1_qua_info = in1->info()->quantization_info().uniform();
231 const UniformQuantizationInfo input2_qua_info = in2->info()->quantization_info().uniform();
232
233 // Clear X Dimension on execution window as we handle manually
234 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
235 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
236
237 Iterator input1(in1, input1_win);
238 Iterator input2(in2, input2_win);
239 Iterator output(out, win);
240
241 execute_window_loop(win, [&](const Coordinates &)
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100242 {
Sheri Zhanga449a362020-07-16 15:52:25 +0100243 const auto input1_ptr = reinterpret_cast<const T *>(input1.ptr());
244 const auto input2_ptr = reinterpret_cast<const T *>(input2.ptr());
245 const auto output_ptr = reinterpret_cast<T *>(output.ptr());
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100246
Sheri Zhanga449a362020-07-16 15:52:25 +0100247 // Compute window_step_x elements per iteration
248 int x = window_start_x;
249 for(; x <= (window_end_x - window_step_x); x += window_step_x)
250 {
251 const auto input1_q = wrapper::vloadq(input1_ptr + x);
252 const auto input2_q = wrapper::vloadq(input2_ptr + x);
253
254 // Dequantize inputs
255 const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
256 const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);
257
258 const float32x4x4_t out_f32x4x4 =
259 {
260 vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
261 vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
262 vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
263 vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
264 };
265
266 // Quantize output
267 const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
268 wrapper::vstore(output_ptr + x, result);
269 }
270
271 // Compute left-over elements
272 for(; x < window_end_x; ++x)
273 {
274 // Dequantize inputs
Michele Di Giorgio40aad9b2020-07-22 15:17:43 +0100275 const T in1 = *(input1_ptr + x);
276 const T in2 = *(input2_ptr + x);
277 const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(in1, input1_qua_info);
278 const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(in2, input2_qua_info);
279 const float tmp_f = tmp_in1 * tmp_in2;
Sheri Zhanga449a362020-07-16 15:52:25 +0100280
281 // Quantize output
Michele Di Giorgio40aad9b2020-07-22 15:17:43 +0100282 const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info);
Sheri Zhanga449a362020-07-16 15:52:25 +0100283 *(output_ptr + x) = tmp_qua;
284 }
285 },
286 input1, input2, output);
287 }
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100288}
289
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100290void mul_saturate_QSYMM16_QSYMM16_QSYMM16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
291{
292 const UniformQuantizationInfo input1_qua_info = in1->info()->quantization_info().uniform();
293 const UniformQuantizationInfo input2_qua_info = in2->info()->quantization_info().uniform();
294 const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();
295
296 // Create input windows
297 Window win = window;
298 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
299 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
300
301 // Clear X Dimension on execution window as we handle manually
302 win.set(Window::DimX, Window::Dimension(0, 1, 1));
303 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
304 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
305
306 Iterator input1(in1, input1_win);
307 Iterator input2(in2, input2_win);
308 Iterator output(out, win);
309
310 const int window_step_x = 16;
311 const auto window_start_x = static_cast<int>(window.x().start());
312 const auto window_end_x = static_cast<int>(window.x().end());
313
314 const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };
315
316 execute_window_loop(win, [&](const Coordinates &)
317 {
318 const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
319 const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
320 const auto output_ptr = reinterpret_cast<qsymm16_t *>(output.ptr());
321
322 // Compute window_step_x elements per iteration
323 int x = window_start_x;
324 for(; x <= (window_end_x - window_step_x); x += window_step_x)
325 {
326 const qsymm16x8x2_t input1_q =
327 {
328 {
329 vld1q_s16(input1_ptr + x),
330 vld1q_s16(input1_ptr + x + 8),
331 }
332 };
333 const qsymm16x8x2_t input2_q =
334 {
335 {
336 vld1q_s16(input2_ptr + x),
337 vld1q_s16(input2_ptr + x + 8),
338 }
339 };
340
341 // Dequantize inputs
342 const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
343 const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);
344
345 const float32x4x4_t out_f32x4x4 =
346 {
347 vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
348 vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
349 vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
350 vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
351 };
352
353 const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info);
354 vst1q_s16(output_ptr + x, result.val[0]);
355 vst1q_s16(output_ptr + x + 8, result.val[1]);
356 }
357
358 // Compute left-over elements
359 for(; x < window_end_x; ++x)
360 {
361 // Dequantize inputs
362 float tmp_in1 = static_cast<float>(*(input1_ptr + x)) * input1_qua_info.scale;
363 float tmp_in2 = static_cast<float>(*(input2_ptr + x)) * input2_qua_info.scale;
364 float tmp_f = tmp_in1 * tmp_in2;
365
366 // Quantize output, lrintf() has same rounding mode as vcombine_s16
367 int32_t tmp = lrintf(tmp_f / tmp_qua_info.scale);
368 qsymm16_t tmp_qua = static_cast<qsymm16_t>(tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
369 *(output_ptr + x) = tmp_qua;
370 }
371 },
372 input1, input2, output);
373}
374
/** Element-wise multiplication of two QSYMM16 tensors into an S32 output.
 *
 * The raw 16-bit quantized values are widened to 32 bits and multiplied as plain
 * integers — no dequantization and no scaling is applied (@p scale is ignored;
 * validate_arguments() only permits this combination with scale == 1).
 *
 * @param[in]  in1    First input tensor (QSYMM16).
 * @param[in]  in2    Second input tensor (QSYMM16).
 * @param[out] out    Output tensor (S32).
 * @param[in]  window Region to process.
 * @param[in]  scale  Unused.
 */
void mul_QSYMM16_QSYMM16_S32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int scale)
{
    ARM_COMPUTE_UNUSED(scale);

    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int32_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const qsymm16x8x2_t input1_q =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const qsymm16x8x2_t input2_q =
            {
                {
                    vld1q_s16(input2_ptr + x),
                    vld1q_s16(input2_ptr + x + 8),
                }
            };

            // Widen the 16 lanes of each input to 32 bits so the product cannot overflow.
            const int32x4x4_t in1_s32 =
            {
                {
                    vmovl_s16(vget_low_s16(input1_q.val[0])),
                    vmovl_s16(vget_high_s16(input1_q.val[0])),
                    vmovl_s16(vget_low_s16(input1_q.val[1])),
                    vmovl_s16(vget_high_s16(input1_q.val[1])),
                }
            };
            const int32x4x4_t in2_s32 =
            {
                {
                    vmovl_s16(vget_low_s16(input2_q.val[0])),
                    vmovl_s16(vget_high_s16(input2_q.val[0])),
                    vmovl_s16(vget_low_s16(input2_q.val[1])),
                    vmovl_s16(vget_high_s16(input2_q.val[1])),
                }
            };

            const int32x4x4_t result =
            {
                {
                    vmulq_s32(in1_s32.val[0], in2_s32.val[0]),
                    vmulq_s32(in1_s32.val[1], in2_s32.val[1]),
                    vmulq_s32(in1_s32.val[2], in2_s32.val[2]),
                    vmulq_s32(in1_s32.val[3], in2_s32.val[3]),
                }
            };

            vst1q_s32(output_ptr + x, result.val[0]);
            vst1q_s32(output_ptr + x + 4, result.val[1]);
            vst1q_s32(output_ptr + x + 8, result.val[2]);
            vst1q_s32(output_ptr + x + 12, result.val[3]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp       = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
            *(output_ptr + x) = tmp;
        }
    },
    input1, input2, output);
}
466
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100467template <bool is_scale255, bool is_sat>
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100468void mul_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100469{
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100470 // Create input windows
471 Window win = window;
472 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
473 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100474
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100475 // Clear X Dimension on execution window as we handle manually
476 win.set(Window::DimX, Window::Dimension(0, 1, 1));
477 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
478 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100479
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100480 Iterator input1(in1, input1_win);
481 Iterator input2(in2, input2_win);
482 Iterator output(out, win);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100483
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100484 const int window_step_x = 16 / sizeof(uint8_t);
485 const auto window_start_x = static_cast<int>(window.x().start());
486 const auto window_end_x = static_cast<int>(window.x().end());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100487
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100488 execute_window_loop(win, [&](const Coordinates &)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100489 {
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100490 const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
491 const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
492 const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100493
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100494 // Compute window_step_x elements per iteration
495 int x = window_start_x;
496 for(; x <= (window_end_x - window_step_x); x += window_step_x)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100497 {
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100498 const uint8x16_t ta1 = wrapper::vloadq(input1_ptr + x);
499 const uint8x16_t ta2 = wrapper::vloadq(input2_ptr + x);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100500
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100501 uint16x8_t tmp1_high = vmovl_u8(vget_high_u8(ta1));
502 const uint16x8_t tmp2_high = vmovl_u8(vget_high_u8(ta2));
503 uint16x8_t tmp1_low = vmovl_u8(vget_low_u8(ta1));
504 const uint16x8_t tmp2_low = vmovl_u8(vget_low_u8(ta2));
505
506 tmp1_high = vmulq_u16(tmp1_high, tmp2_high);
507 tmp1_low = vmulq_u16(tmp1_low, tmp2_low);
508
509 if(is_scale255)
510 {
511 tmp1_high = scale255_U16_U16(tmp1_high);
512 tmp1_low = scale255_U16_U16(tmp1_low);
513 }
514 else
515 {
516 const int16x8_t vn = vdupq_n_s16(-n);
517
518 if(is_sat)
519 {
520 tmp1_high = vqshlq_u16(tmp1_high, vn);
521 tmp1_low = vqshlq_u16(tmp1_low, vn);
522 }
523 else
524 {
525 tmp1_high = vshlq_u16(tmp1_high, vn);
526 tmp1_low = vshlq_u16(tmp1_low, vn);
527 }
528 }
529 if(is_sat)
530 {
531 vst1q_u8(output_ptr, vcombine_u8(vqmovn_u16(tmp1_low), vqmovn_u16(tmp1_high)));
532 }
533 else
534 {
535 vst1q_u8(output_ptr, vcombine_u8(vmovn_u16(tmp1_low), vmovn_u16(tmp1_high)));
536 }
537 }
538
539 // Compute left-over elements
540 for(; x < window_end_x; ++x)
541 {
542 uint16_t tmp = static_cast<uint16_t>(*(input1_ptr + x)) * static_cast<uint16_t>(*(input2_ptr + x));
543
544 if(is_scale255)
545 {
546 float tmp_f = static_cast<float>(tmp) * scale255_constant;
547 tmp = static_cast<uint16_t>(tmp_f + 0.5f);
548 }
549 else
550 {
551 tmp >>= n;
552 }
553 if(is_sat && tmp > 255)
554 {
555 tmp = 255;
556 }
557 *(output_ptr + x) = static_cast<uint8_t>(tmp);
558 }
559 },
560 input1, input2, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100561}
562
/** Multiplies two s16x8 vectors element-wise, scales the 32-bit products and narrows back to s16.
 *
 * @tparam is_scale255 True to scale the product by 1/255 (round half up); otherwise the product
 *                     is divided by 2^n with rounding toward zero.
 * @tparam is_sat      True to saturate on the shift and on narrowing back to 16 bits.
 *
 * @param input1 First input vector.
 * @param input2 Second input vector.
 * @param n      Right-shift amount used when is_scale255 is false.
 * @return The scaled product narrowed to s16x8.
 */
template <bool is_scale255, bool is_sat>
inline int16x8_t mul_S16_S16_S16_n_loop(const int16x8_t &input1, const int16x8_t &input2, int n)
{
    // Widen to s32 so the 16-bit product cannot overflow.
    int32x4_t       tmp1_high = vmovl_s16(vget_high_s16(input1));
    const int32x4_t tmp2_high = vmovl_s16(vget_high_s16(input2));
    int32x4_t       tmp1_low  = vmovl_s16(vget_low_s16(input1));
    const int32x4_t tmp2_low  = vmovl_s16(vget_low_s16(input2));

    tmp1_high = vmulq_s32(tmp1_high, tmp2_high);
    tmp1_low  = vmulq_s32(tmp1_low, tmp2_low);

    if(is_scale255)
    {
        tmp1_high = scale255_S32_S32(tmp1_high);
        tmp1_low  = scale255_S32_S32(tmp1_low);
    }
    else
    {
        // Right shift amount
        const int32x4_t vn = vdupq_n_s32(-n);
        // Left shift amount
        const int32x4_t vnl = vdupq_n_s32(n);
        // Calculate conversion bit: for negative lanes the sign bit is 1, so
        // convert = (sign << n) - sign = 2^n - 1; adding it before the arithmetic
        // right shift turns floor-division into round-toward-zero. For positive
        // lanes convert is 0 and the shift is unchanged.
        const uint32x4_t tmp1_high_u  = vreinterpretq_u32_s32(tmp1_high);
        const uint32x4_t tmp1_low_u   = vreinterpretq_u32_s32(tmp1_low);
        const uint32x4_t sign_high    = vshrq_n_u32(tmp1_high_u, 31);
        const uint32x4_t sign_low     = vshrq_n_u32(tmp1_low_u, 31);
        const int32x4_t  sign_high_s  = vreinterpretq_s32_u32(sign_high);
        const int32x4_t  sign_low_s   = vreinterpretq_s32_u32(sign_low);
        const int32x4_t  convert_high = vsubq_s32(vshlq_s32(sign_high_s, vnl), sign_high_s);
        const int32x4_t  convert_low  = vsubq_s32(vshlq_s32(sign_low_s, vnl), sign_low_s);
        if(is_sat)
        {
            tmp1_high = vqshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
            tmp1_low  = vqshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
        }
        else
        {
            tmp1_high = vshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
            tmp1_low  = vshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
        }
    }

    // Narrow back to s16, saturating if requested.
    if(is_sat)
    {
        return vcombine_s16(vqmovn_s32(tmp1_low), vqmovn_s32(tmp1_high));
    }
    else
    {
        return vcombine_s16(vmovn_s32(tmp1_low), vmovn_s32(tmp1_high));
    }
}
615
/** Applies mul_S16_S16_S16_n_loop() to both 8-element halves of a 16-element s16 pair.
 *
 * @tparam is_scale255 Forwarded to the per-half helper (1/255 scaling).
 * @tparam is_sat      Forwarded to the per-half helper (saturating narrow/shift).
 */
template <bool is_scale255, bool is_sat>
inline int16x8x2_t mul_S16_S16_S16_n_k(const int16x8x2_t &input1, const int16x8x2_t &input2, int n)
{
    int16x8x2_t out;
    // Process the two 8-lane halves independently.
    out.val[0] = mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[0], input2.val[0], n);
    out.val[1] = mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[1], input2.val[1], n);
    return out;
}
631
632template <bool is_scale255, bool is_sat>
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100633void mul_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100634{
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100635 // Create input windows
636 Window win = window;
637 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
638 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100639
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100640 // Clear X Dimension on execution window as we handle manually
641 win.set(Window::DimX, Window::Dimension(0, 1, 1));
642 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
643 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100644
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100645 Iterator input1(in1, input1_win);
646 Iterator input2(in2, input2_win);
647 Iterator output(out, win);
648
649 const int window_step_x = 16;
650 const auto window_start_x = static_cast<int>(window.x().start());
651 const auto window_end_x = static_cast<int>(window.x().end());
652
653 execute_window_loop(win, [&](const Coordinates &)
654 {
655 const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
656 const auto input2_ptr = reinterpret_cast<const int16_t *>(input2.ptr());
657 const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
658
659 // Compute window_step_x elements per iteration
660 int x = window_start_x;
661 for(; x <= (window_end_x - window_step_x); x += window_step_x)
662 {
663 const int16x8x2_t ta1 =
664 {
665 {
666 vld1q_s16(input1_ptr + x),
667 vld1q_s16(input1_ptr + x + 8),
668 }
669 };
670 const int16x8x2_t ta2 =
671 {
672 {
673 vld1q_s16(input2_ptr + x),
674 vld1q_s16(input2_ptr + x + 8),
675 }
676 };
677 const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);
678
679 vst1q_s16(output_ptr + x, result.val[0]);
680 vst1q_s16(output_ptr + x + 8, result.val[1]);
681 }
682
683 // Compute left-over elements
684 for(; x < window_end_x; ++x)
685 {
686 int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
687
688 if(is_scale255)
689 {
690 float tmp_f = static_cast<float>(tmp) * scale255_constant;
691
692 tmp = static_cast<int32_t>(tmp_f + 0.5f);
693 }
694 else
695 {
696 if(tmp >= 0)
697 {
698 tmp >>= n;
699 }
700 else
701 {
702 uint32_t mask = (1u << n) - 1;
703 tmp = (tmp + static_cast<int32_t>(mask)) >> n;
704 }
705 }
706 if(is_sat)
707 {
708 tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
709 }
710 *(output_ptr + x) = static_cast<int16_t>(tmp);
711 }
712 },
713 input1, input2, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100714}
715
SiCong Libb88f892020-08-28 11:18:47 +0100716template <bool is_sat>
717inline int32x4_t mul_S32_S32_S32_n_loop(const int32x4_t &input1, const int32x4_t &input2, int n)
718{
719 const int32x2_t input1_1 = vget_low_s32(input1);
720 const int32x2_t input2_1 = vget_low_s32(input2);
721 const int32x2_t input1_2 = vget_high_s32(input1);
722 const int32x2_t input2_2 = vget_high_s32(input2);
723
724 int64x2_t tmp_1 = vmull_s32(input1_1, input2_1);
725 int64x2_t tmp_2 = vmull_s32(input1_2, input2_2);
726
727 // Apply scaling, conversion and rounding (round to zero)
728 // Right shift amount
729 const int64x2_t vn = vdupq_n_s64(-n);
730 // Left shift amount
731 const int64x2_t vnl = vdupq_n_s64(n);
732 // Calculate conversion bit
733 const uint64x2_t tmp_1_u = vreinterpretq_u64_s64(tmp_1);
734 const uint64x2_t sign_1 = vshrq_n_u64(tmp_1_u, 63);
735 const int64x2_t sign_1_s = vreinterpretq_s64_u64(sign_1);
736 const int64x2_t convert_1 = vsubq_s64(vshlq_s64(sign_1_s, vnl), sign_1_s);
737
738 const uint64x2_t tmp_2_u = vreinterpretq_u64_s64(tmp_2);
739 const uint64x2_t sign_2 = vshrq_n_u64(tmp_2_u, 63);
740 const int64x2_t sign_2_s = vreinterpretq_s64_u64(sign_2);
741 const int64x2_t convert_2 = vsubq_s64(vshlq_s64(sign_2_s, vnl), sign_2_s);
742 if(is_sat)
743 {
744 tmp_1 = vqshlq_s64(vaddq_s64(tmp_1, convert_1), vn);
745 tmp_2 = vqshlq_s64(vaddq_s64(tmp_2, convert_2), vn);
746 return vcombine_s32(vqmovn_s64(tmp_1), vqmovn_s64(tmp_2));
747 }
748 else
749 {
750 tmp_1 = vshlq_s64(vaddq_s64(tmp_1, convert_1), vn);
751 tmp_2 = vshlq_s64(vaddq_s64(tmp_2, convert_2), vn);
752 return vcombine_s32(vmovn_s64(tmp_1), vmovn_s64(tmp_2));
753 }
754}
755
756template <bool is_sat>
757inline int32x4x2_t mul_S32_S32_S32_n_k(const int32x4x2_t &input1, const int32x4x2_t &input2, int n)
758{
759 const int32x4x2_t result =
760 {
761 {
762 // First 4 elements
763 mul_S32_S32_S32_n_loop<is_sat>(input1.val[0], input2.val[0], n),
764 // Second 4 elements
765 mul_S32_S32_S32_n_loop<is_sat>(input1.val[1], input2.val[1], n)
766 }
767 };
768
769 return result;
770}
771
772template <bool is_sat>
773void mul_S32_S32_S32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
774{
775 // Create input windows
SiCong Libb88f892020-08-28 11:18:47 +0100776 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
777 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
778
779 // Clear X Dimension on execution window as we handle manually
SiCong Lid6d1b362020-09-24 17:34:23 +0100780 Window win = window;
SiCong Libb88f892020-08-28 11:18:47 +0100781 win.set(Window::DimX, Window::Dimension(0, 1, 1));
SiCong Libb88f892020-08-28 11:18:47 +0100782
SiCong Lid6d1b362020-09-24 17:34:23 +0100783 const int window_step_x = 8;
784 const auto window_start_x = static_cast<int>(window.x().start());
785 const auto window_end_x = static_cast<int>(window.x().end());
786 const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);
SiCong Libb88f892020-08-28 11:18:47 +0100787
SiCong Lid6d1b362020-09-24 17:34:23 +0100788 if(is_broadcast_across_x)
SiCong Libb88f892020-08-28 11:18:47 +0100789 {
SiCong Lid6d1b362020-09-24 17:34:23 +0100790 const bool is_broadcast_input_2 = input2_win.x().step() == 0;
791 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
792 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
793 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
794 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
SiCong Libb88f892020-08-28 11:18:47 +0100795
SiCong Lid6d1b362020-09-24 17:34:23 +0100796 // Clear X Dimension on execution window as we handle manually
797 non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
798
799 Iterator broadcast_input(broadcast_tensor, broadcast_win);
800 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
801 Iterator output(out, win);
802
803 execute_window_loop(win, [&](const Coordinates &)
SiCong Libb88f892020-08-28 11:18:47 +0100804 {
SiCong Lid6d1b362020-09-24 17:34:23 +0100805 const auto non_broadcast_input_ptr = reinterpret_cast<const int32_t *>(non_broadcast_input.ptr());
806 const auto output_ptr = reinterpret_cast<int32_t *>(output.ptr());
SiCong Libb88f892020-08-28 11:18:47 +0100807
SiCong Lid6d1b362020-09-24 17:34:23 +0100808 const int32_t broadcast_value = *reinterpret_cast<const int32_t *>(broadcast_input.ptr());
809 const auto broadcast_value_vec = vdupq_n_s32(broadcast_value);
SiCong Libb88f892020-08-28 11:18:47 +0100810
SiCong Lid6d1b362020-09-24 17:34:23 +0100811 // Compute window_step_x elements per iteration
812 int x = window_start_x;
813 for(; x <= (window_end_x - window_step_x); x += window_step_x)
814 {
815 const int32x4x2_t broadcast_v =
816 {
817 {
818 broadcast_value_vec,
819 broadcast_value_vec,
820 }
821 };
822 const int32x4x2_t non_broadcast_v =
823 {
824 {
825 vld1q_s32(non_broadcast_input_ptr + x),
826 vld1q_s32(non_broadcast_input_ptr + x + 4),
827 }
828 };
829 const int32x4x2_t result = mul_S32_S32_S32_n_k<is_sat>(broadcast_v, non_broadcast_v, n);
830
831 vst1q_s32(output_ptr + x, result.val[0]);
832 vst1q_s32(output_ptr + x + 4, result.val[1]);
833 }
834
835 // Compute left-over elements
836 for(; x < window_end_x; ++x)
837 {
838 int64_t tmp = static_cast<int64_t>(broadcast_value) * static_cast<int64_t>(*(non_broadcast_input_ptr + x));
839
840 if(tmp >= 0)
841 {
842 tmp >>= n;
843 }
844 else
845 {
846 uint64_t mask = (1u << n) - 1;
847 tmp = (tmp + static_cast<int64_t>(mask)) >> n;
848 }
849 if(is_sat)
850 {
851 tmp = utility::clamp<int64_t, int32_t>(tmp);
852 }
853 *(output_ptr + x) = static_cast<int32_t>(tmp);
854 }
855 },
856 broadcast_input, non_broadcast_input, output);
857 }
858 else
859 {
860 // Clear X Dimension on execution window as we handle manually
861 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
862 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
863
864 Iterator input1(in1, input1_win);
865 Iterator input2(in2, input2_win);
866 Iterator output(out, win);
867
868 execute_window_loop(win, [&](const Coordinates &)
SiCong Libb88f892020-08-28 11:18:47 +0100869 {
SiCong Lid6d1b362020-09-24 17:34:23 +0100870 const auto input1_ptr = reinterpret_cast<const int32_t *>(input1.ptr());
871 const auto input2_ptr = reinterpret_cast<const int32_t *>(input2.ptr());
872 const auto output_ptr = reinterpret_cast<int32_t *>(output.ptr());
SiCong Libb88f892020-08-28 11:18:47 +0100873
SiCong Lid6d1b362020-09-24 17:34:23 +0100874 // Compute window_step_x elements per iteration
875 int x = window_start_x;
876 for(; x <= (window_end_x - window_step_x); x += window_step_x)
SiCong Libb88f892020-08-28 11:18:47 +0100877 {
SiCong Lid6d1b362020-09-24 17:34:23 +0100878 const int32x4x2_t ta1 =
879 {
880 {
881 vld1q_s32(input1_ptr + x),
882 vld1q_s32(input1_ptr + x + 4),
883 }
884 };
885 const int32x4x2_t ta2 =
886 {
887 {
888 vld1q_s32(input2_ptr + x),
889 vld1q_s32(input2_ptr + x + 4),
890 }
891 };
892 const int32x4x2_t result = mul_S32_S32_S32_n_k<is_sat>(ta1, ta2, n);
893
894 vst1q_s32(output_ptr + x, result.val[0]);
895 vst1q_s32(output_ptr + x + 4, result.val[1]);
SiCong Libb88f892020-08-28 11:18:47 +0100896 }
SiCong Lid6d1b362020-09-24 17:34:23 +0100897
898 // Compute left-over elements
899 for(; x < window_end_x; ++x)
SiCong Libb88f892020-08-28 11:18:47 +0100900 {
SiCong Lid6d1b362020-09-24 17:34:23 +0100901 int64_t tmp = static_cast<int64_t>(*(input1_ptr + x)) * static_cast<int64_t>(*(input2_ptr + x));
902
903 if(tmp >= 0)
904 {
905 tmp >>= n;
906 }
907 else
908 {
909 uint64_t mask = (1u << n) - 1;
910 tmp = (tmp + static_cast<int64_t>(mask)) >> n;
911 }
912 if(is_sat)
913 {
914 tmp = utility::clamp<int64_t, int32_t>(tmp);
915 }
916 *(output_ptr + x) = static_cast<int32_t>(tmp);
SiCong Libb88f892020-08-28 11:18:47 +0100917 }
SiCong Lid6d1b362020-09-24 17:34:23 +0100918 },
919 input1, input2, output);
920 }
SiCong Libb88f892020-08-28 11:18:47 +0100921}
922
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100923void mul_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100924{
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100925 // Create input windows
926 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
927 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100928
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100929 // Clear X Dimension on execution window as we handle manually
930 Window win = window;
931 win.set(Window::DimX, Window::Dimension(0, 1, 1));
932
933 constexpr int window_step_x = 16 / sizeof(float);
934 const auto window_start_x = static_cast<int>(window.x().start());
935 const auto window_end_x = static_cast<int>(window.x().end());
936 const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);
937
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100938 using ExactTagType = typename wrapper::traits::neon_vector<float, window_step_x>::tag_type;
939
940 if(is_broadcast_across_x)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100941 {
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100942 const bool is_broadcast_input_2 = input2_win.x().step() == 0;
943 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
944 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
945 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
946 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
947
948 // Clear X Dimension on execution window as we handle manually
949 non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
950
951 Iterator broadcast_input(broadcast_tensor, broadcast_win);
952 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
953 Iterator output(out, win);
954
955 execute_window_loop(win, [&](const Coordinates &)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100956 {
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100957 const auto non_broadcast_input_ptr = reinterpret_cast<const float *>(non_broadcast_input.ptr());
958 const auto output_ptr = reinterpret_cast<float *>(output.ptr());
959
960 const float broadcast_value = *reinterpret_cast<const float *>(broadcast_input.ptr());
961 const auto broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});
962 const auto scale_vec = wrapper::vdup_n(scale, ExactTagType{});
963
964 // Compute window_step_x elements per iteration
965 int x = window_start_x;
966 for(; x <= (window_end_x - window_step_x); x += window_step_x)
967 {
968 const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);
969 auto res = wrapper::vmul(wrapper::vmul(broadcast_value_vec, non_broadcast_v), scale_vec);
970 wrapper::vstore(output_ptr + x, res);
971 }
972
973 // Compute left-over elements
974 for(; x < window_end_x; ++x)
975 {
976 const auto non_broadcast_v = *(non_broadcast_input_ptr + x);
977 *(output_ptr + x) = broadcast_value * non_broadcast_v * scale;
978 }
979 },
980 broadcast_input, non_broadcast_input, output);
981 }
982 else
983 {
984 // Clear X Dimension on execution window as we handle manually
985 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
986 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
987
988 Iterator input1(in1, input1_win);
989 Iterator input2(in2, input2_win);
990 Iterator output(out, win);
991
992 execute_window_loop(win, [&](const Coordinates &)
993 {
994 const auto input1_ptr = reinterpret_cast<const float *>(input1.ptr());
995 const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr());
996 const auto output_ptr = reinterpret_cast<float *>(output.ptr());
997
998 // Compute window_step_x elements per iteration
999 int x = window_start_x;
1000 for(; x <= (window_end_x - window_step_x); x += window_step_x)
1001 {
1002 const auto ta1 = wrapper::vloadq(input1_ptr + x);
1003 const auto ta2 = wrapper::vloadq(input2_ptr + x);
1004 const auto scale_vec = wrapper::vdup_n(scale, ExactTagType{});
1005 const auto res = wrapper::vmul(wrapper::vmul(ta1, ta2), scale_vec);
1006 wrapper::vstore(output_ptr + x, res);
1007 }
1008
1009 // Compute left-over elements
1010 for(; x < window_end_x; ++x)
1011 {
1012 const auto ta1 = *(input1_ptr + x);
1013 const auto ta2 = *(input2_ptr + x);
1014 *(output_ptr + x) = ta1 * ta2 * scale;
1015 }
1016 },
1017 input1, input2, output);
1018 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001019}
1020
// Complex multiplication of two pairs of F32 complex numbers.
// Each input holds 4 floats interpreted as two (real, imag) pairs:
// (a0 + i*a1, a2 + i*a3) * (b0 + i*b1, b2 + i*b3).
// The lane shuffles below implement, per pair:
//   real = ar*br - ai*bi,  imag = ar*bi + ai*br
void c_mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr)
{
    const auto input1 = static_cast<const float *__restrict>(input1_ptr);
    const auto input2 = static_cast<const float *__restrict>(input2_ptr);
    const auto output = static_cast<float *__restrict>(output_ptr);

    const float32x4_t a = wrapper::vloadq(input1);
    float32x4_t       b = wrapper::vloadq(input2);

    using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;

    // Sign pattern used to negate bi in the real part of the product
    const float32x4_t mask  = { -1.0f, 1.0f, -1.0f, 1.0f };
    // Broadcast the real (tmp?0) and imaginary (tmp?1) parts of each complex in a
    const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
    const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
    const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{});
    const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{});

    // tmp0 = (a0, a0, a2, a2) - real parts; tmp1 = (a1, a1, a3, a3) - imaginary parts
    const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10);
    const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11);

    // res = (a0*b0, a0*b1, a2*b2, a2*b3)
    float32x4_t res = wrapper::vmul(tmp0, b);

    // b becomes (-b1, b0, -b3, b2): swap real/imag in each pair, negate the new real lane
    b = wrapper::vrev64(b);
    b = wrapper::vmul(b, mask);

    // Accumulate the imaginary-part contribution: res += (a1, a1, a3, a3) * (-b1, b0, -b3, b2)
    res = wrapper::vmla(res, tmp1, b);
    wrapper::vstore(output, res);
}
1049
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001050#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001051void mul_F16_F16_F16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
1052{
1053 // Create input windows
1054 Window win = window;
1055 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
1056 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
1057
1058 // Clear X Dimension on execution window as we handle manually
1059 win.set(Window::DimX, Window::Dimension(0, 1, 1));
1060 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1061 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1062
1063 Iterator input1(in1, input1_win);
1064 Iterator input2(in2, input2_win);
1065 Iterator output(out, win);
1066
1067 const int window_step_x = 16;
1068 const auto window_start_x = static_cast<int>(window.x().start());
1069 const auto window_end_x = static_cast<int>(window.x().end());
1070
1071 execute_window_loop(win, [&](const Coordinates &)
Michele Di Giorgio9428a182020-03-30 14:10:20 +01001072 {
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001073 const auto input1_ptr = reinterpret_cast<const float16_t *>(input1.ptr());
1074 const auto input2_ptr = reinterpret_cast<const float16_t *>(input2.ptr());
1075 const auto output_ptr = reinterpret_cast<float16_t *>(output.ptr());
1076
1077 // Compute window_step_x elements per iteration
1078 int x = window_start_x;
1079 for(; x <= (window_end_x - window_step_x); x += window_step_x)
Michele Di Giorgio9428a182020-03-30 14:10:20 +01001080 {
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001081 const float16x8x2_t ta1 =
1082 {
1083 {
1084 vld1q_f16(input1_ptr + x),
1085 vld1q_f16(input1_ptr + x + 8),
1086 }
1087 };
1088 const float16x8x2_t ta2 =
1089 {
1090 {
1091 vld1q_f16(input2_ptr + x),
1092 vld1q_f16(input2_ptr + x + 8),
1093 }
1094 };
1095 const float16x8_t scale_vec = vdupq_n_f16(scale);
1096 const float16x8x2_t result =
1097 {
1098 {
1099 vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec),
1100 vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec),
1101 }
1102 };
1103 vst1q_f16(output_ptr + x, result.val[0]);
1104 vst1q_f16(output_ptr + x + 8, result.val[1]);
Michele Di Giorgio9428a182020-03-30 14:10:20 +01001105 }
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001106
1107 // Compute left-over elements
1108 for(; x < window_end_x; ++x)
Michele Di Giorgio9428a182020-03-30 14:10:20 +01001109 {
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001110 const auto ta1 = *(input1_ptr + x);
1111 const auto ta2 = *(input2_ptr + x);
1112 *(output_ptr + x) = ta1 * ta2 * scale;
Michele Di Giorgio9428a182020-03-30 14:10:20 +01001113 }
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001114 },
1115 input1, input2, output);
1116}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001117#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tellodf246182017-07-03 16:25:09 +01001118
// Element-wise multiplication U8 x U8 -> S16 over the given execution window.
// Inputs are widened to 16 bits before multiplying, so the product (max 65025)
// never overflows and is never negative.
// is_scale255: scale by 1/255 (via scale255_U16_U16, defined earlier in this file);
// otherwise divide by 2^n. is_sat: clamp to SHRT_MAX (no lower clamp needed,
// the product is non-negative).
template <bool is_scale255, bool is_sat>
void mul_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    // 16 uint8 elements per vectorized iteration
    const int  window_step_x  = 16 / sizeof(uint8_t);
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const uint8x16_t bv = wrapper::vloadq(input2_ptr + x);
            const uint8x16_t av = wrapper::vloadq(input1_ptr + x);

            // Widen each half of the inputs to u16 and multiply (no overflow possible)
            uint16x8_t tmp_low  = vmovl_u8(vget_low_u8(av));
            uint16x8_t tmp_high = vmovl_u8(vget_high_u8(av));
            tmp_low             = vmulq_u16(tmp_low, vmovl_u8(vget_low_u8(bv)));
            tmp_high            = vmulq_u16(tmp_high, vmovl_u8(vget_high_u8(bv)));

            if(is_scale255)
            {
                tmp_low  = scale255_U16_U16(tmp_low);
                tmp_high = scale255_U16_U16(tmp_high);
            }
            else
            {
                // Negative shift count performs a right shift
                const int16x8_t vn = vdupq_n_s16(-n);

                if(is_sat)
                {
                    tmp_low  = vqshlq_u16(tmp_low, vn);
                    tmp_high = vqshlq_u16(tmp_high, vn);
                }
                else
                {
                    tmp_low  = vshlq_u16(tmp_low, vn);
                    tmp_high = vshlq_u16(tmp_high, vn);
                }
            }

            if(is_sat)
            {
                // Only an upper clamp is required: the unsigned product cannot be negative
                static const uint16x8_t max = vdupq_n_u16(SHRT_MAX);

                tmp_low  = vminq_u16(tmp_low, max);
                tmp_high = vminq_u16(tmp_high, max);
            }

            vst1q_s16(output_ptr + x, vreinterpretq_s16_u16(tmp_low));
            vst1q_s16(output_ptr + x + 8, vreinterpretq_s16_u16(tmp_high));
        }

        // Compute left-over elements (scalar tail)
        for(; x < window_end_x; ++x)
        {
            int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                // Multiply by 1/255 and round to nearest by adding 0.5 before truncation
                float tmp_f = static_cast<float>(tmp) * scale255_constant;
                tmp         = static_cast<int32_t>(tmp_f + 0.5f);
            }
            else
            {
                tmp >>= n;
            }

            if(is_sat)
            {
                tmp = (tmp > SHRT_MAX) ? SHRT_MAX : tmp;
            }

            *(output_ptr + x) = static_cast<int16_t>(tmp);
        }
    },
    input1, input2, output);
}
1216
// Element-wise multiplication S16 x U8 -> S16 over the given execution window.
// The U8 input is widened to S16 so the shared S16 x S16 core
// (mul_S16_S16_S16_n_k, defined earlier in this file) can be reused.
// is_scale255 / is_sat / n: see mul_S16_S16_S16.
template <bool is_scale255, bool is_sat>
void mul_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    // 16 elements per vectorized iteration
    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const int16x8x2_t ta1 =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const uint8x8x2_t ta2u =
            {
                {
                    vld1_u8(input2_ptr + x),
                    vld1_u8(input2_ptr + x + 8),
                }
            };
            // Widen the u8 operand to s16 (values <= 255, so the reinterpret is lossless)
            const int16x8x2_t ta2 =
            {
                {
                    vreinterpretq_s16_u16(vmovl_u8(ta2u.val[0])),
                    vreinterpretq_s16_u16(vmovl_u8(ta2u.val[1]))
                }
            };

            const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);

            vst1q_s16(output_ptr + x, result.val[0]);
            vst1q_s16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements (scalar tail)
        for(; x < window_end_x; ++x)
        {
            int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                // Multiply by 1/255 and round to nearest by adding 0.5 before truncation
                float tmp_f = static_cast<float>(tmp) * scale255_constant;

                tmp = static_cast<int32_t>(tmp_f + 0.5f);
            }
            else
            {
                if(tmp >= 0)
                {
                    tmp >>= n;
                }
                else
                {
                    // Add (2^n - 1) before the arithmetic shift so negative values
                    // round towards zero (n < 16 here, so a 32-bit mask is sufficient)
                    uint32_t mask = (1u << n) - 1;
                    tmp           = (tmp + static_cast<int32_t>(mask)) >> n;
                }
            }
            if(is_sat)
            {
                tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
            }
            *(output_ptr + x) = static_cast<int16_t>(tmp);
        }
    },
    input1, input2, output);
}
1308
1309template <bool is_scale255, bool is_sat>
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001310void mul_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001311{
1312 // Simply swap the two input buffers
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001313 mul_S16_U8_S16<is_scale255, is_sat>(in2, in1, out, window, n);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001314}
1315} // namespace
1316
// Default constructor: no kernel function is selected (all function pointers
// null) and scale state is zeroed until configure() runs.
NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel()
    : _func_float(nullptr), _func_int(nullptr), _func_quantized(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
{
}
1321
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001322void NEPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001323{
Ioan-Cristian Szabo754e9522017-11-28 18:29:43 +00001324 ARM_COMPUTE_UNUSED(rounding_policy);
Georgios Pinitasf0dea702017-07-03 18:17:28 +01001325 ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
1326
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001327 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001328
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001329 const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001330 const TensorShape &out_shape = broadcast_pair.first;
1331 const ValidRegion &valid_region = broadcast_pair.second;
Michalis Spyrou861f0db2018-02-26 16:47:58 +00001332
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001333 // Auto initialize output if not initialized
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001334 set_shape_if_empty(*output, out_shape);
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001335
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001336 _scale = scale;
1337 _scale_exponent = 0;
1338 _func_quantized = nullptr;
1339 _func_int = nullptr;
1340 _func_float = nullptr;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001341
1342 bool is_scale_255 = false;
1343 // Check and validate scaling factor
1344 if(std::abs(scale - scale255_constant) < 0.00001f)
1345 {
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001346 is_scale_255 = true;
1347 }
1348 else
1349 {
Ioan-Cristian Szabo754e9522017-11-28 18:29:43 +00001350 int exponent = 0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001351
Ioan-Cristian Szabo754e9522017-11-28 18:29:43 +00001352 std::frexp(scale, &exponent);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001353
Ioan-Cristian Szabo754e9522017-11-28 18:29:43 +00001354 // Store the positive exponent. We know that we compute 1/2^n
1355 // Additionally we need to subtract 1 to compensate that frexp used a mantissa of 0.5
1356 _scale_exponent = std::abs(exponent - 1);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001357 }
1358
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001359 const DataType dt_input1 = input1->data_type();
1360 const DataType dt_input2 = input2->data_type();
1361 const DataType dt_output = output->data_type();
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001362 const bool is_sat = (overflow_policy == ConvertPolicy::SATURATE);
1363
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001364 switch(dt_input1)
Manuel Bottini79fa9a22019-02-22 17:54:22 +00001365 {
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001366 case DataType::QASYMM8:
1367 if(dt_input2 == DataType::QASYMM8 && dt_output == DataType::QASYMM8)
1368 {
1369 _func_quantized = &mul_saturate_quantized_8<uint8_t>;
1370 }
1371 break;
1372 case DataType::QASYMM8_SIGNED:
1373 if(dt_input2 == DataType::QASYMM8_SIGNED)
1374 {
1375 _func_quantized = &mul_saturate_quantized_8<int8_t>;
1376 ;
1377 }
1378 break;
1379 case DataType::QSYMM16:
1380 if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::QSYMM16)
1381 {
1382 _func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16;
1383 }
1384 else if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::S32)
1385 {
1386 _func_int = &mul_QSYMM16_QSYMM16_S32;
1387 }
1388 break;
1389 case DataType::S16:
1390 if(DataType::U8 == dt_input2 && DataType::S16 == dt_output)
1391 {
1392 if(is_scale_255)
1393 {
1394 _func_int = is_sat ? &mul_S16_U8_S16<true, true> : &mul_S16_U8_S16<true, false>;
1395 }
1396 else
1397 {
1398 _func_int = is_sat ? &mul_S16_U8_S16<false, true> : &mul_S16_U8_S16<false, false>;
1399 }
1400 }
1401 if(DataType::S16 == dt_input2 && DataType::S16 == dt_output)
1402 {
1403 if(is_scale_255)
1404 {
1405 _func_int = is_sat ? &mul_S16_S16_S16<true, true> : &mul_S16_S16_S16<true, false>;
1406 }
1407 else
1408 {
1409 _func_int = is_sat ? &mul_S16_S16_S16<false, true> : &mul_S16_S16_S16<false, false>;
1410 }
1411 }
1412 break;
SiCong Libb88f892020-08-28 11:18:47 +01001413 case DataType::S32:
1414 if(DataType::S32 == dt_input2 && DataType::S32 == dt_output)
1415 {
1416 _func_int = is_sat ? &mul_S32_S32_S32<true> : &mul_S32_S32_S32<false>;
1417 }
1418 break;
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001419 case DataType::U8:
1420 if(DataType::U8 == dt_input2 && DataType::U8 == dt_output)
1421 {
1422 if(is_scale_255)
1423 {
1424 _func_int = is_sat ? &mul_U8_U8_U8<true, true> : &mul_U8_U8_U8<true, false>;
1425 }
1426 else
1427 {
1428 _func_int = is_sat ? &mul_U8_U8_U8<false, true> : &mul_U8_U8_U8<false, false>;
1429 }
1430 }
1431 else if(DataType::U8 == dt_input2 && DataType::S16 == dt_output)
1432 {
1433 if(is_scale_255)
1434 {
1435 _func_int = is_sat ? &mul_U8_U8_S16<true, true> : &mul_U8_U8_S16<true, false>;
1436 }
1437 else
1438 {
1439 _func_int = is_sat ? &mul_U8_U8_S16<false, true> : &mul_U8_U8_S16<false, false>;
1440 }
1441 }
1442 else if(DataType::S16 == dt_input2 && DataType::S16 == dt_output)
1443 {
1444 if(is_scale_255)
1445 {
1446 _func_int = is_sat ? &mul_U8_S16_S16<true, true> : &mul_U8_S16_S16<true, false>;
1447 }
1448 else
1449 {
1450 _func_int = is_sat ? &mul_U8_S16_S16<false, true> : &mul_U8_S16_S16<false, false>;
1451 }
1452 }
1453 break;
1454#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1455 case DataType::F16:
1456 _func_float = &mul_F16_F16_F16;
1457 break;
1458#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1459 case DataType::F32:
1460 _func_float = &mul_F32_F32_F32;
1461 break;
1462 default:
1463 ARM_COMPUTE_ERROR("You called with the wrong img formats");
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001464 }
1465
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001466 // Configure kernel window
1467 Coordinates coord;
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001468 coord.set_num_dimensions(output->num_dimensions());
1469 output->set_valid_region(valid_region);
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001470 Window win = calculate_max_window(valid_region, Steps());
1471
1472 INEKernel::configure(win);
Ioan-Cristian Szabo754e9522017-11-28 18:29:43 +00001473}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001474
Georgios Pinitas631c41a2017-12-06 11:53:03 +00001475Status NEPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy,
1476 RoundingPolicy rounding_policy)
Ioan-Cristian Szabo754e9522017-11-28 18:29:43 +00001477{
Michalis Spyrou861f0db2018-02-26 16:47:58 +00001478 ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
Ioan-Cristian Szabo754e9522017-11-28 18:29:43 +00001479 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001480
Georgios Pinitas631c41a2017-12-06 11:53:03 +00001481 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001482}
1483
Georgios Pinitas0499dff2020-07-31 22:21:38 +01001484void NEPixelWiseMultiplicationKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001485{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001486 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001487 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1488 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1489
Georgios Pinitas0499dff2020-07-31 22:21:38 +01001490 auto input1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
1491 auto input2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
1492 auto output = tensors.get_tensor(TensorType::ACL_DST);
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001493
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001494 if(_func_quantized != nullptr)
Michalis Spyrou861f0db2018-02-26 16:47:58 +00001495 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001496 (*_func_quantized)(input1, input2, output, window, _scale);
Manuel Bottini79fa9a22019-02-22 17:54:22 +00001497 }
1498 else if(_func_int != nullptr)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001499 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001500 (*_func_int)(input1, input2, output, window, _scale_exponent);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001501 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001502 else
1503 {
1504 ARM_COMPUTE_ERROR_ON(_func_float == nullptr);
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001505 (*_func_float)(input1, input2, output, window, _scale);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001506 }
1507}
namespace
{
// Number of complex elements (each a 2-channel F32 value) processed per vector iteration.
constexpr unsigned int num_elems_processed_per_iteration_complex = 2;

/** Validate tensor infos for the complex (2-channel F32) pixel-wise multiplication.
 *
 * Both inputs must be 2-channel F32 and broadcast-compatible; if the output has
 * already been configured it must also be 2-channel F32 with the broadcast shape.
 */
Status validate_arguments_complex(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 2, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 2, DataType::F32);

    // An empty broadcast shape signals the two input shapes are incompatible.
    const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

    // Validate in case of configured output
    if(output->total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 2, DataType::F32);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
    }

    return Status{};
}

/** Compute the execution window for the complex multiplication and set up padding.
 *
 * Auto-initialises @p output from the broadcast shape if it is still empty.
 * Returns the window paired with a Status that is an error if the required
 * padding could not be satisfied.
 */
std::pair<Status, Window> validate_and_configure_window_complex(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
{
    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
    const TensorShape &out_shape    = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    const TensorInfo out_info(out_shape, input1->num_channels(), input1->data_type());
    auto_init_if_empty(*output, out_info);

    // Collapse size-1 dimensions of each input so a broadcast operand is walked in place.
    Window win        = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration_complex));
    Window win_input1 = win.broadcast_if_dimension_le_one(*input1);
    Window win_input2 = win.broadcast_if_dimension_le_one(*input2);

    AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration_complex);
    AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_complex);
    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_complex);

    // NOTE(review): '||' short-circuits, so once one update reports a change the
    // remaining windows are not updated; any change is treated as an error below,
    // so the skipped updates are moot — but confirm this matches the intent.
    bool window_changed = update_window_and_padding(win_input1, input1_access)
                          || update_window_and_padding(win_input2, input2_access)
                          || update_window_and_padding(win, output_access);

    output_access.set_valid_region(win, valid_region);

    // Any required window/padding adjustment means the caller's buffers lack padding.
    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
} // namespace
1559
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001560void NEComplexPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
giuros01154bc1c2019-03-26 17:44:40 +00001561{
1562 ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001563 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_complex(input1, input2, output));
giuros01154bc1c2019-03-26 17:44:40 +00001564
1565 // Configure kernel window
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001566 auto win_config = validate_and_configure_window_complex(input1, input2, output);
giuros01154bc1c2019-03-26 17:44:40 +00001567 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
1568
giuros01154bc1c2019-03-26 17:44:40 +00001569 // Create kernel
1570 INEKernel::configure(win_config.second);
1571}
1572
1573Status NEComplexPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
1574{
1575 ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
1576 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_complex(input1, input2, output));
1577 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_complex(input1->clone().get(), input2->clone().get(), output->clone().get()).first);
1578
1579 return Status{};
1580}
1581
Georgios Pinitas0499dff2020-07-31 22:21:38 +01001582void NEComplexPixelWiseMultiplicationKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
giuros01154bc1c2019-03-26 17:44:40 +00001583{
1584 ARM_COMPUTE_UNUSED(info);
1585 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1586 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1587
Georgios Pinitas0499dff2020-07-31 22:21:38 +01001588 auto input1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
1589 auto input2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
1590 auto output = tensors.get_tensor(TensorType::ACL_DST);
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001591
1592 Iterator input1_it(input1, window.broadcast_if_dimension_le_one(input1->info()->tensor_shape()));
1593 Iterator input2_it(input2, window.broadcast_if_dimension_le_one(input2->info()->tensor_shape()));
1594 Iterator output_it(output, window);
giuros01154bc1c2019-03-26 17:44:40 +00001595
1596 execute_window_loop(window, [&](const Coordinates &)
1597 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001598 c_mul_F32_F32_F32_n(input1_it.ptr(), input2_it.ptr(), output_it.ptr());
giuros01154bc1c2019-03-26 17:44:40 +00001599 },
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001600 input1_it, input2_it, output_it);
giuros01154bc1c2019-03-26 17:44:40 +00001601}
Manuel Bottini79fa9a22019-02-22 17:54:22 +00001602} // namespace arm_compute