/*
 * Copyright (c) 2016-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"

#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/TensorInfo.h"
#include "src/core/NEON/NEAsymm.h"
#include "src/core/NEON/NESymm.h"
#include "src/core/NEON/wrapper/wrapper.h"

#include <arm_neon.h>

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_fp16.h> // needed for float16_t
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

namespace arm_compute
{
namespace
{
const float       scale255_constant      = 1.f / 255.f;
const float32x4_t scale255_constant_f32q = vdupq_n_f32(scale255_constant);
const float32x4_t positive_round_f32q    = vdupq_n_f32(0.5f);

inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_UNUSED(overflow_policy);
    ARM_COMPUTE_UNUSED(rounding_policy);

    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::S32, DataType::QSYMM16, DataType::F16,
                                                         DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::S32, DataType::QSYMM16, DataType::F16,
                                                         DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
                                                         DataType::S16, DataType::QSYMM16,
                                                         DataType::S32, DataType::F16, DataType::F32);
    if(is_data_type_quantized(input1->data_type()) || is_data_type_quantized(input2->data_type()))
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(overflow_policy == ConvertPolicy::WRAP, "ConvertPolicy cannot be WRAP if datatype is quantized");
    }

    if(output->total_size() > 0)
    {
        const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
        // clang-format off
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(
            !(input1->data_type() == input2->data_type() && input2->data_type() == output->data_type()) &&
            !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) &&
            !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16) &&
            !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) &&
            !(input1->data_type() == DataType::QSYMM16 && input2->data_type() == DataType::QSYMM16 && output->data_type() == DataType::S32)
            , "Invalid data type combination");
        // clang-format on
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::QSYMM16 && output->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 output");
    }

    if(std::abs(scale - scale255_constant) < 0.00001f)
    {
        ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::S32 && input2->data_type() == DataType::S32 && output->data_type() == DataType::S32,
                                        "Scale == 1/255 is not supported if input and output are of data type S32");
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO);

        int         exponent            = 0;
        const float normalized_mantissa = std::frexp(scale, &exponent);

        // Use int scaling if the factor is equal to 1/2^n for 0 <= n <= 15
        // frexp returns 0.5 as the mantissa, so the exponent is 1 - n and lies in the range -14 <= e <= 1
        // e.g. scale = 1/8 = 0.5 * 2^-2 gives exponent = -2
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(!((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1)), "Scale value not supported (Should be 1/(2^n) or 1/255)");
    }

    return Status{};
}

/* Scales a given vector by 1/255.
 *
 * @note This does not work for all cases, e.g. for a float of 0.49999999999999994 or for large floats.
 *
 * @param in Input vector to scale.
 * @return   Scaled output rounded to nearest (round half up).
 */
inline int32x4_t scale255_S32_S32(int32x4_t in)
{
    // Scale
    const float32x4_t tmp = vmulq_f32(vcvtq_f32_s32(in), scale255_constant_f32q);
    // Round to nearest (round half up)
    // Add +0.5 for all values
    // Afterwards vcvt rounds toward zero
    return vcvtq_s32_f32(vaddq_f32(tmp, positive_round_f32q));
}

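// Scales a uint16x8_t vector by 1/255: widen each half to 32 bits, reuse scale255_S32_S32 on
// both halves and narrow the results back to 16 bits.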
inline uint16x8_t scale255_U16_U16(uint16x8_t in)
{
    const int32x4_t tmp_s1 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(in))));
    const int32x4_t tmp_s2 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(in))));
    return vreinterpretq_u16_s16(vcombine_s16(vmovn_s32(tmp_s2), vmovn_s32(tmp_s1)));
}

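// Select the right NEON quantization helper at compile time so the templated quantized kernel
// below can call vquantize<T> uniformly: the int8_t overload forwards to vquantize_signed
// (QASYMM8_SIGNED) and the uint8_t overload to the non-template vquantize (QASYMM8).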
template <typename T>
inline typename std::enable_if<std::is_same<T, int8_t>::value, int8x16_t>::type
vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
{
    return vquantize_signed(val, info);
}

template <typename T>
inline typename std::enable_if<std::is_same<T, uint8_t>::value, uint8x16_t>::type
vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
{
    return vquantize(val, info);
}

template <typename T>
void mul_saturate_quantized_8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16 / sizeof(T);
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

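    // Fold the multiplication scale into the requantization parameters: quantizing with
    // tmp_scale = output_scale / scale computes dequant(in1) * dequant(in2) * scale in one step.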
    const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();
    const UniformQuantizationInfo tmp_qua_info    = { output_qua_info.scale / scale, output_qua_info.offset };

    if(is_broadcast_across_x)
    {
        const bool                    is_broadcast_input_2 = input2_win.x().step() == 0;
        Window                        broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window                        non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor                *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor                *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
        const UniformQuantizationInfo broadcast_qinfo      = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo  = non_broadcast_tensor->info()->quantization_info().uniform();

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        using ExactTagType = typename wrapper::traits::neon_vector<T, window_step_x>::tag_type;

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const T *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<T *>(output.ptr());

            const auto broadcast_value     = *reinterpret_cast<const T *>(broadcast_input.ptr());
            const auto broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);

                // Dequantize inputs
                const float32x4x4_t in1_f32x4x4 = vdequantize(non_broadcast_v, non_broadcast_qinfo);
                const float32x4x4_t in2_f32x4x4 = vdequantize(broadcast_value_vec, broadcast_qinfo);

                const float32x4x4_t out_f32x4x4 =
                {
                    vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
                    vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
                    vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
                    vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
                };

                // Quantize output
                const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
                wrapper::vstore(output_ptr + x, result);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                // Dequantize inputs
                const T     in1     = *(non_broadcast_input_ptr + x);
                const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(in1, non_broadcast_qinfo);
                const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(broadcast_value, broadcast_qinfo);
                const float tmp_f   = tmp_in1 * tmp_in2;

                // Quantize output
                const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info);
                *(output_ptr + x)  = tmp_qua;
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qua_info = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qua_info = in2->info()->quantization_info().uniform();

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const T *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const T *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<T *>(output.ptr());

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto input1_q = wrapper::vloadq(input1_ptr + x);
                const auto input2_q = wrapper::vloadq(input2_ptr + x);

                // Dequantize inputs
                const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
                const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);

                const float32x4x4_t out_f32x4x4 =
                {
                    vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
                    vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
                    vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
                    vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
                };

                // Quantize output
                const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
                wrapper::vstore(output_ptr + x, result);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                // Dequantize inputs
                const T     in1     = *(input1_ptr + x);
                const T     in2     = *(input2_ptr + x);
                const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(in1, input1_qua_info);
                const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(in2, input2_qua_info);
                const float tmp_f   = tmp_in1 * tmp_in2;

                // Quantize output
                const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info);
                *(output_ptr + x)  = tmp_qua;
            }
        },
        input1, input2, output);
    }
}

void mul_saturate_QSYMM16_QSYMM16_QSYMM16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    const UniformQuantizationInfo input1_qua_info = in1->info()->quantization_info().uniform();
    const UniformQuantizationInfo input2_qua_info = in2->info()->quantization_info().uniform();
    const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();

    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<qsymm16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const qsymm16x8x2_t input1_q =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const qsymm16x8x2_t input2_q =
            {
                {
                    vld1q_s16(input2_ptr + x),
                    vld1q_s16(input2_ptr + x + 8),
                }
            };

            // Dequantize inputs
            const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
            const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);

            const float32x4x4_t out_f32x4x4 =
            {
                vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
                vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
                vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
                vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
            };

            const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info);
            vst1q_s16(output_ptr + x, result.val[0]);
            vst1q_s16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            // Dequantize inputs
            float tmp_in1 = static_cast<float>(*(input1_ptr + x)) * input1_qua_info.scale;
            float tmp_in2 = static_cast<float>(*(input2_ptr + x)) * input2_qua_info.scale;
            float tmp_f   = tmp_in1 * tmp_in2;

            // Quantize output: lrintf() rounds to nearest, matching the rounding of the vectorized vquantize_qsymm16 path above
            int32_t   tmp     = lrintf(tmp_f / tmp_qua_info.scale);
            qsymm16_t tmp_qua = static_cast<qsymm16_t>((tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp));
            *(output_ptr + x) = tmp_qua;
        }
    },
    input1, input2, output);
}

void mul_QSYMM16_QSYMM16_S32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int scale)
{
    ARM_COMPUTE_UNUSED(scale);
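    // validate_arguments() only admits the QSYMM16 x QSYMM16 -> S32 combination with scale == 1,
    // so the widened products are stored without rescaling.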

    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int32_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const qsymm16x8x2_t input1_q =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const qsymm16x8x2_t input2_q =
            {
                {
                    vld1q_s16(input2_ptr + x),
                    vld1q_s16(input2_ptr + x + 8),
                }
            };

            const int32x4x4_t in1_s32 =
            {
                {
                    vmovl_s16(vget_low_s16(input1_q.val[0])),
                    vmovl_s16(vget_high_s16(input1_q.val[0])),
                    vmovl_s16(vget_low_s16(input1_q.val[1])),
                    vmovl_s16(vget_high_s16(input1_q.val[1])),
                }
            };
            const int32x4x4_t in2_s32 =
            {
                {
                    vmovl_s16(vget_low_s16(input2_q.val[0])),
                    vmovl_s16(vget_high_s16(input2_q.val[0])),
                    vmovl_s16(vget_low_s16(input2_q.val[1])),
                    vmovl_s16(vget_high_s16(input2_q.val[1])),
                }
            };

            const int32x4x4_t result =
            {
                {
                    vmulq_s32(in1_s32.val[0], in2_s32.val[0]),
                    vmulq_s32(in1_s32.val[1], in2_s32.val[1]),
                    vmulq_s32(in1_s32.val[2], in2_s32.val[2]),
                    vmulq_s32(in1_s32.val[3], in2_s32.val[3]),
                }
            };

            vst1q_s32(output_ptr + x, result.val[0]);
            vst1q_s32(output_ptr + x + 4, result.val[1]);
            vst1q_s32(output_ptr + x + 8, result.val[2]);
            vst1q_s32(output_ptr + x + 12, result.val[3]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp       = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
            *(output_ptr + x) = tmp;
        }
    },
    input1, input2, output);
}

template <bool is_scale255, bool is_sat>
void mul_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16 / sizeof(uint8_t);
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const uint8x16_t ta1 = wrapper::vloadq(input1_ptr + x);
            const uint8x16_t ta2 = wrapper::vloadq(input2_ptr + x);

            uint16x8_t       tmp1_high = vmovl_u8(vget_high_u8(ta1));
            const uint16x8_t tmp2_high = vmovl_u8(vget_high_u8(ta2));
            uint16x8_t       tmp1_low  = vmovl_u8(vget_low_u8(ta1));
            const uint16x8_t tmp2_low  = vmovl_u8(vget_low_u8(ta2));

            tmp1_high = vmulq_u16(tmp1_high, tmp2_high);
            tmp1_low  = vmulq_u16(tmp1_low, tmp2_low);

            if(is_scale255)
            {
                tmp1_high = scale255_U16_U16(tmp1_high);
                tmp1_low  = scale255_U16_U16(tmp1_low);
            }
            else
            {
                const int16x8_t vn = vdupq_n_s16(-n);

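                // NEON vector shifts take a signed per-lane shift count: shifting left by -n
                // performs the required right shift by n.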
                if(is_sat)
                {
                    tmp1_high = vqshlq_u16(tmp1_high, vn);
                    tmp1_low  = vqshlq_u16(tmp1_low, vn);
                }
                else
                {
                    tmp1_high = vshlq_u16(tmp1_high, vn);
                    tmp1_low  = vshlq_u16(tmp1_low, vn);
                }
            }
            if(is_sat)
            {
                vst1q_u8(output_ptr + x, vcombine_u8(vqmovn_u16(tmp1_low), vqmovn_u16(tmp1_high)));
            }
            else
            {
                vst1q_u8(output_ptr + x, vcombine_u8(vmovn_u16(tmp1_low), vmovn_u16(tmp1_high)));
            }
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            uint16_t tmp = static_cast<uint16_t>(*(input1_ptr + x)) * static_cast<uint16_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;
                tmp         = static_cast<uint16_t>(tmp_f + 0.5f);
            }
            else
            {
                tmp >>= n;
            }
            if(is_sat && tmp > 255)
            {
                tmp = 255;
            }
            *(output_ptr + x) = static_cast<uint8_t>(tmp);
        }
    },
    input1, input2, output);
}

template <bool is_scale255, bool is_sat>
inline int16x8_t mul_S16_S16_S16_n_loop(const int16x8_t &input1, const int16x8_t &input2, int n)
{
    int32x4_t       tmp1_high = vmovl_s16(vget_high_s16(input1));
    const int32x4_t tmp2_high = vmovl_s16(vget_high_s16(input2));
    int32x4_t       tmp1_low  = vmovl_s16(vget_low_s16(input1));
    const int32x4_t tmp2_low  = vmovl_s16(vget_low_s16(input2));

    tmp1_high = vmulq_s32(tmp1_high, tmp2_high);
    tmp1_low  = vmulq_s32(tmp1_low, tmp2_low);

    if(is_scale255)
    {
        tmp1_high = scale255_S32_S32(tmp1_high);
        tmp1_low  = scale255_S32_S32(tmp1_low);
    }
    else
    {
        // Right shift amount
        const int32x4_t vn = vdupq_n_s32(-n);
        // Left shift amount
        const int32x4_t vnl = vdupq_n_s32(n);
        // Calculate conversion bit
        const uint32x4_t tmp1_high_u  = vreinterpretq_u32_s32(tmp1_high);
        const uint32x4_t tmp1_low_u   = vreinterpretq_u32_s32(tmp1_low);
        const uint32x4_t sign_high    = vshrq_n_u32(tmp1_high_u, 31);
        const uint32x4_t sign_low     = vshrq_n_u32(tmp1_low_u, 31);
        const int32x4_t  sign_high_s  = vreinterpretq_s32_u32(sign_high);
        const int32x4_t  sign_low_s   = vreinterpretq_s32_u32(sign_low);
        const int32x4_t  convert_high = vsubq_s32(vshlq_s32(sign_high_s, vnl), sign_high_s);
        const int32x4_t  convert_low  = vsubq_s32(vshlq_s32(sign_low_s, vnl), sign_low_s);
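        // convert = sign * ((1 << n) - 1): adding this bias to negative products before the
        // arithmetic right shift makes the shift round toward zero (RoundingPolicy::TO_ZERO)
        // instead of toward negative infinity, matching the scalar left-over path.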
        if(is_sat)
        {
            tmp1_high = vqshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
            tmp1_low  = vqshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
        }
        else
        {
            tmp1_high = vshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
            tmp1_low  = vshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
        }
    }

    if(is_sat)
    {
        return vcombine_s16(vqmovn_s32(tmp1_low), vqmovn_s32(tmp1_high));
    }
    else
    {
        return vcombine_s16(vmovn_s32(tmp1_low), vmovn_s32(tmp1_high));
    }
}

template <bool is_scale255, bool is_sat>
inline int16x8x2_t mul_S16_S16_S16_n_k(const int16x8x2_t &input1, const int16x8x2_t &input2, int n)
{
    const int16x8x2_t result =
    {
        {
            // First 8 elements
            mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[0], input2.val[0], n),
            // Second 8 elements
            mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[1], input2.val[1], n)
        }
    };

    return result;
}

template <bool is_scale255, bool is_sat>
void mul_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const int16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const int16x8x2_t ta1 =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const int16x8x2_t ta2 =
            {
                {
                    vld1q_s16(input2_ptr + x),
                    vld1q_s16(input2_ptr + x + 8),
                }
            };
            const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);

            vst1q_s16(output_ptr + x, result.val[0]);
            vst1q_s16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;

                tmp = static_cast<int32_t>(tmp_f + 0.5f);
            }
            else
            {
                if(tmp >= 0)
                {
                    tmp >>= n;
                }
                else
                {
                    uint32_t mask = (1u << n) - 1;
                    tmp           = (tmp + static_cast<int32_t>(mask)) >> n;
                }
            }
            if(is_sat)
            {
                tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
            }
            *(output_ptr + x) = static_cast<int16_t>(tmp);
        }
    },
    input1, input2, output);
}

template <bool is_sat>
inline int32x4_t mul_S32_S32_S32_n_loop(const int32x4_t &input1, const int32x4_t &input2, int n)
{
    const int32x2_t input1_1 = vget_low_s32(input1);
    const int32x2_t input2_1 = vget_low_s32(input2);
    const int32x2_t input1_2 = vget_high_s32(input1);
    const int32x2_t input2_2 = vget_high_s32(input2);

    int64x2_t tmp_1 = vmull_s32(input1_1, input2_1);
    int64x2_t tmp_2 = vmull_s32(input1_2, input2_2);

    // Apply scaling, conversion and rounding (round to zero)
    // Right shift amount
    const int64x2_t vn = vdupq_n_s64(-n);
    // Left shift amount
    const int64x2_t vnl = vdupq_n_s64(n);
    // Calculate conversion bit
    const uint64x2_t tmp_1_u   = vreinterpretq_u64_s64(tmp_1);
    const uint64x2_t sign_1    = vshrq_n_u64(tmp_1_u, 63);
    const int64x2_t  sign_1_s  = vreinterpretq_s64_u64(sign_1);
    const int64x2_t  convert_1 = vsubq_s64(vshlq_s64(sign_1_s, vnl), sign_1_s);

    const uint64x2_t tmp_2_u   = vreinterpretq_u64_s64(tmp_2);
    const uint64x2_t sign_2    = vshrq_n_u64(tmp_2_u, 63);
    const int64x2_t  sign_2_s  = vreinterpretq_s64_u64(sign_2);
    const int64x2_t  convert_2 = vsubq_s64(vshlq_s64(sign_2_s, vnl), sign_2_s);
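    // As in mul_S16_S16_S16_n_loop, convert = sign * ((1 << n) - 1) biases negative products so
    // the arithmetic right shift below rounds toward zero.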
    if(is_sat)
    {
        tmp_1 = vqshlq_s64(vaddq_s64(tmp_1, convert_1), vn);
        tmp_2 = vqshlq_s64(vaddq_s64(tmp_2, convert_2), vn);
        return vcombine_s32(vqmovn_s64(tmp_1), vqmovn_s64(tmp_2));
    }
    else
    {
        tmp_1 = vshlq_s64(vaddq_s64(tmp_1, convert_1), vn);
        tmp_2 = vshlq_s64(vaddq_s64(tmp_2, convert_2), vn);
        return vcombine_s32(vmovn_s64(tmp_1), vmovn_s64(tmp_2));
    }
}

template <bool is_sat>
inline int32x4x2_t mul_S32_S32_S32_n_k(const int32x4x2_t &input1, const int32x4x2_t &input2, int n)
{
    const int32x4x2_t result =
    {
        {
            // First 4 elements
            mul_S32_S32_S32_n_loop<is_sat>(input1.val[0], input2.val[0], n),
            // Second 4 elements
            mul_S32_S32_S32_n_loop<is_sat>(input1.val[1], input2.val[1], n)
        }
    };

    return result;
}

template <bool is_sat>
void mul_S32_S32_S32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 8;
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    if(is_broadcast_across_x)
    {
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const int32_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<int32_t *>(output.ptr());

            const int32_t broadcast_value     = *reinterpret_cast<const int32_t *>(broadcast_input.ptr());
            const auto    broadcast_value_vec = vdupq_n_s32(broadcast_value);

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const int32x4x2_t broadcast_v =
                {
                    {
                        broadcast_value_vec,
                        broadcast_value_vec,
                    }
                };
                const int32x4x2_t non_broadcast_v =
                {
                    {
                        vld1q_s32(non_broadcast_input_ptr + x),
                        vld1q_s32(non_broadcast_input_ptr + x + 4),
                    }
                };
                const int32x4x2_t result = mul_S32_S32_S32_n_k<is_sat>(broadcast_v, non_broadcast_v, n);

                vst1q_s32(output_ptr + x, result.val[0]);
                vst1q_s32(output_ptr + x + 4, result.val[1]);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                int64_t tmp = static_cast<int64_t>(broadcast_value) * static_cast<int64_t>(*(non_broadcast_input_ptr + x));

                if(tmp >= 0)
                {
                    tmp >>= n;
                }
                else
                {
                    uint64_t mask = (1u << n) - 1;
                    tmp           = (tmp + static_cast<int64_t>(mask)) >> n;
                }
                if(is_sat)
                {
                    tmp = utility::clamp<int64_t, int32_t>(tmp);
                }
                *(output_ptr + x) = static_cast<int32_t>(tmp);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const int32_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const int32_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<int32_t *>(output.ptr());

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const int32x4x2_t ta1 =
                {
                    {
                        vld1q_s32(input1_ptr + x),
                        vld1q_s32(input1_ptr + x + 4),
                    }
                };
                const int32x4x2_t ta2 =
                {
                    {
                        vld1q_s32(input2_ptr + x),
                        vld1q_s32(input2_ptr + x + 4),
                    }
                };
                const int32x4x2_t result = mul_S32_S32_S32_n_k<is_sat>(ta1, ta2, n);

                vst1q_s32(output_ptr + x, result.val[0]);
                vst1q_s32(output_ptr + x + 4, result.val[1]);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                int64_t tmp = static_cast<int64_t>(*(input1_ptr + x)) * static_cast<int64_t>(*(input2_ptr + x));

                if(tmp >= 0)
                {
                    tmp >>= n;
                }
                else
                {
                    uint64_t mask = (1u << n) - 1;
                    tmp           = (tmp + static_cast<int64_t>(mask)) >> n;
                }
                if(is_sat)
                {
                    tmp = utility::clamp<int64_t, int32_t>(tmp);
                }
                *(output_ptr + x) = static_cast<int32_t>(tmp);
            }
        },
        input1, input2, output);
    }
}

void mul_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    constexpr int window_step_x         = 16 / sizeof(float);
    const auto    window_start_x        = static_cast<int>(window.x().start());
    const auto    window_end_x          = static_cast<int>(window.x().end());
    const bool    is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    using ExactTagType = typename wrapper::traits::neon_vector<float, window_step_x>::tag_type;

    if(is_broadcast_across_x)
    {
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const float *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<float *>(output.ptr());

            const float broadcast_value     = *reinterpret_cast<const float *>(broadcast_input.ptr());
            const auto  broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});
            const auto  scale_vec           = wrapper::vdup_n(scale, ExactTagType{});

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);
                auto       res             = wrapper::vmul(wrapper::vmul(broadcast_value_vec, non_broadcast_v), scale_vec);
                wrapper::vstore(output_ptr + x, res);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const auto non_broadcast_v = *(non_broadcast_input_ptr + x);
                *(output_ptr + x)          = broadcast_value * non_broadcast_v * scale;
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const float *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<float *>(output.ptr());

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto ta1       = wrapper::vloadq(input1_ptr + x);
                const auto ta2       = wrapper::vloadq(input2_ptr + x);
                const auto scale_vec = wrapper::vdup_n(scale, ExactTagType{});
                const auto res       = wrapper::vmul(wrapper::vmul(ta1, ta2), scale_vec);
                wrapper::vstore(output_ptr + x, res);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const auto ta1    = *(input1_ptr + x);
                const auto ta2    = *(input2_ptr + x);
                *(output_ptr + x) = ta1 * ta2 * scale;
            }
        },
        input1, input2, output);
    }
}

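// Pixel-wise multiplication of complex tensors: each element is a complex number stored as two
// consecutive floats (real, imaginary).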
void c_mul_F32_F32_F32_n(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    constexpr int window_step_x         = 8 / sizeof(float);
    const auto    window_start_x        = static_cast<int>(window.x().start());
    const auto    window_end_x          = static_cast<int>(window.x().end());
    const bool    is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    if(is_broadcast_across_x)
    {
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const float *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<float *>(output.ptr());

            const float broadcast_value_r = *reinterpret_cast<const float *>(broadcast_input.ptr());
            const float broadcast_value_i = *(reinterpret_cast<const float *>(broadcast_input.ptr()) + 1);

            // Process the row one complex element (2 floats) at a time; the broadcast case has no vectorized path
            int x = window_start_x;
            for(; x < window_end_x; ++x)
            {
                const auto non_broadcast_value_r = *(non_broadcast_input_ptr + 2 * x);
                const auto non_broadcast_value_i = *(non_broadcast_input_ptr + 2 * x + 1);
                auto       res1                  = broadcast_value_r * non_broadcast_value_r - broadcast_value_i * non_broadcast_value_i;
                auto       res2                  = broadcast_value_r * non_broadcast_value_i + broadcast_value_i * non_broadcast_value_r;
                *(output_ptr + 2 * x)     = res1;
                *(output_ptr + 2 * x + 1) = res2;
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const float *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<float *>(output.ptr());

            using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const float32x4_t a = wrapper::vloadq(input1_ptr + 2 * x);
                float32x4_t       b = wrapper::vloadq(input2_ptr + 2 * x);

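                // Each 128-bit vector holds two complex values laid out as (re0, im0, re1, im1).
                // Using (a.re + a.im i) * (b.re + b.im i) = (a.re * b.re - a.im * b.im) + (a.re * b.im + a.im * b.re) i:
                // tmp0/tmp1 broadcast the real/imaginary lanes of a, vrev64 swaps the re/im lanes
                // of b, and the mask negates the swapped imaginary products so a single
                // multiply-accumulate yields both components.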
                const float32x4_t mask  = { -1.0f, 1.0f, -1.0f, 1.0f };
                const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
                const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
                const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{});
                const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{});

                const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10);
                const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11);

                float32x4_t res = wrapper::vmul(tmp0, b);

                b = wrapper::vrev64(b);
                b = wrapper::vmul(b, mask);

                res = wrapper::vmla(res, tmp1, b);
                wrapper::vstore(output_ptr + 2 * x, res);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const auto a0   = *(input1_ptr + 2 * x);
                const auto a1   = *(input1_ptr + 2 * x + 1);
                const auto b0   = *(input2_ptr + 2 * x);
                const auto b1   = *(input2_ptr + 2 * x + 1);
                auto       res1 = a0 * b0 - a1 * b1;
                auto       res2 = a0 * b1 + a1 * b0;
                *(output_ptr + 2 * x)     = res1;
                *(output_ptr + 2 * x + 1) = res2;
            }
        },
        input1, input2, output);
    }
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void mul_F16_F16_F16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const float16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const float16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<float16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const float16x8x2_t ta1 =
            {
                {
                    vld1q_f16(input1_ptr + x),
                    vld1q_f16(input1_ptr + x + 8),
                }
            };
            const float16x8x2_t ta2 =
            {
                {
                    vld1q_f16(input2_ptr + x),
                    vld1q_f16(input2_ptr + x + 8),
                }
            };
            const float16x8_t   scale_vec = vdupq_n_f16(scale);
            const float16x8x2_t result    =
            {
                {
                    vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec),
                    vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec),
                }
            };
            vst1q_f16(output_ptr + x, result.val[0]);
            vst1q_f16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            const auto ta1    = *(input1_ptr + x);
            const auto ta2    = *(input2_ptr + x);
            *(output_ptr + x) = ta1 * ta2 * scale;
        }
    },
    input1, input2, output);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

template <bool is_scale255, bool is_sat>
void mul_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16 / sizeof(uint8_t);
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const uint8x16_t bv = wrapper::vloadq(input2_ptr + x);
            const uint8x16_t av = wrapper::vloadq(input1_ptr + x);

            uint16x8_t tmp_low  = vmovl_u8(vget_low_u8(av));
            uint16x8_t tmp_high = vmovl_u8(vget_high_u8(av));
            tmp_low             = vmulq_u16(tmp_low, vmovl_u8(vget_low_u8(bv)));
            tmp_high            = vmulq_u16(tmp_high, vmovl_u8(vget_high_u8(bv)));

            if(is_scale255)
            {
                tmp_low  = scale255_U16_U16(tmp_low);
                tmp_high = scale255_U16_U16(tmp_high);
            }
            else
            {
                const int16x8_t vn = vdupq_n_s16(-n);

                if(is_sat)
                {
                    tmp_low  = vqshlq_u16(tmp_low, vn);
                    tmp_high = vqshlq_u16(tmp_high, vn);
                }
                else
                {
                    tmp_low  = vshlq_u16(tmp_low, vn);
                    tmp_high = vshlq_u16(tmp_high, vn);
                }
            }

            if(is_sat)
            {
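                // u8 * u8 products fit in 16 unsigned bits but the output is signed 16-bit:
                // clamp to SHRT_MAX before reinterpreting the vectors as int16.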
                static const uint16x8_t max = vdupq_n_u16(SHRT_MAX);

                tmp_low  = vminq_u16(tmp_low, max);
                tmp_high = vminq_u16(tmp_high, max);
            }

            vst1q_s16(output_ptr + x, vreinterpretq_s16_u16(tmp_low));
            vst1q_s16(output_ptr + x + 8, vreinterpretq_s16_u16(tmp_high));
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;
                tmp         = static_cast<int32_t>(tmp_f + 0.5f);
            }
            else
            {
                tmp >>= n;
            }

            if(is_sat)
            {
                tmp = (tmp > SHRT_MAX) ? SHRT_MAX : tmp;
            }

            *(output_ptr + x) = static_cast<int16_t>(tmp);
        }
    },
    input1, input2, output);
}
1298
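// Multiplies an S16 input by a U8 input into an S16 output: the U8 data is
// widened to S16, then the common S16*S16 helper mul_S16_S16_S16_n_k applies
// the scaling (1/255 or right shift by n) and optional saturation.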
template <bool is_scale255, bool is_sat>
void mul_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear the X dimension on the execution window as we handle it manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const int16x8x2_t ta1 =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const uint8x8x2_t ta2u =
            {
                {
                    vld1_u8(input2_ptr + x),
                    vld1_u8(input2_ptr + x + 8),
                }
            };
            const int16x8x2_t ta2 =
            {
                {
                    vreinterpretq_s16_u16(vmovl_u8(ta2u.val[0])),
                    vreinterpretq_s16_u16(vmovl_u8(ta2u.val[1]))
                }
            };

            const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);

            vst1q_s16(output_ptr + x, result.val[0]);
            vst1q_s16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;

                tmp = static_cast<int32_t>(tmp_f + 0.5f);
            }
            else
            {
                if(tmp >= 0)
                {
                    tmp >>= n;
                }
                else
                {
                    uint32_t mask = (1u << n) - 1;
                    tmp           = (tmp + static_cast<int32_t>(mask)) >> n;
                }
            }
            if(is_sat)
            {
                tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
            }
            *(output_ptr + x) = static_cast<int16_t>(tmp);
        }
    },
    input1, input2, output);
}

template <bool is_scale255, bool is_sat>
void mul_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Multiplication is commutative, so simply swap the two inputs and reuse mul_S16_U8_S16
    mul_S16_U8_S16<is_scale255, is_sat>(in2, in1, out, window, n);
}
} // namespace

NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel()
    : _func_float(nullptr), _func_int(nullptr), _func_quantized(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
{
}

void NEPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_UNUSED(rounding_policy);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));

    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
    const TensorShape &out_shape    = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    set_shape_if_empty(*output, out_shape);

    _scale          = scale;
    _scale_exponent = 0;
    _func_quantized = nullptr;
    _func_int       = nullptr;
    _func_float     = nullptr;

    bool is_scale_255 = false;
    // Check and validate scaling factor
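    // A scale within 1e-5 of 1/255 selects the dedicated 1/255 code path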
    if(std::abs(scale - scale255_constant) < 0.00001f)
    {
        is_scale_255 = true;
    }
    else
    {
        int exponent = 0;

        std::frexp(scale, &exponent);

        // Store the positive exponent. We know that we compute 1/2^n.
        // We also subtract 1 to compensate for frexp returning a mantissa of 0.5, not 1.
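        // e.g. for scale = 1/8: frexp(0.125) = 0.5 * 2^-2, so exponent = -2
        // and _scale_exponent = |-2 - 1| = 3, i.e. a right shift by 3 bits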
1438 _scale_exponent = std::abs(exponent - 1);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001439 }
1440
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001441 const DataType dt_input1 = input1->data_type();
1442 const DataType dt_input2 = input2->data_type();
1443 const DataType dt_output = output->data_type();
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001444 const bool is_sat = (overflow_policy == ConvertPolicy::SATURATE);
1445
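    // Select the implementation according to the input/output data type combination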
    switch(dt_input1)
    {
        case DataType::QASYMM8:
            if(dt_input2 == DataType::QASYMM8 && dt_output == DataType::QASYMM8)
            {
                _func_quantized = &mul_saturate_quantized_8<uint8_t>;
            }
            break;
        case DataType::QASYMM8_SIGNED:
            if(dt_input2 == DataType::QASYMM8_SIGNED && dt_output == DataType::QASYMM8_SIGNED)
            {
                _func_quantized = &mul_saturate_quantized_8<int8_t>;
            }
            break;
        case DataType::QSYMM16:
            if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::QSYMM16)
            {
                _func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16;
            }
            else if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::S32)
            {
                _func_int = &mul_QSYMM16_QSYMM16_S32;
            }
            break;
        case DataType::S16:
            if(DataType::U8 == dt_input2 && DataType::S16 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_S16_U8_S16<true, true> : &mul_S16_U8_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_S16_U8_S16<false, true> : &mul_S16_U8_S16<false, false>;
                }
            }
            if(DataType::S16 == dt_input2 && DataType::S16 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_S16_S16_S16<true, true> : &mul_S16_S16_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_S16_S16_S16<false, true> : &mul_S16_S16_S16<false, false>;
                }
            }
            break;
        case DataType::S32:
            if(DataType::S32 == dt_input2 && DataType::S32 == dt_output)
            {
                _func_int = is_sat ? &mul_S32_S32_S32<true> : &mul_S32_S32_S32<false>;
            }
            break;
        case DataType::U8:
            if(DataType::U8 == dt_input2 && DataType::U8 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_U8_U8_U8<true, true> : &mul_U8_U8_U8<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_U8_U8_U8<false, true> : &mul_U8_U8_U8<false, false>;
                }
            }
            else if(DataType::U8 == dt_input2 && DataType::S16 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_U8_U8_S16<true, true> : &mul_U8_U8_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_U8_U8_S16<false, true> : &mul_U8_U8_S16<false, false>;
                }
            }
            else if(DataType::S16 == dt_input2 && DataType::S16 == dt_output)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_U8_S16_S16<true, true> : &mul_U8_S16_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_U8_S16_S16<false, true> : &mul_U8_S16_S16<false, false>;
                }
            }
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func_float = &mul_F16_F16_F16;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func_float = &mul_F32_F32_F32;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported combination of input/output data types");
    }

    // Configure kernel window
    Coordinates coord;
    coord.set_num_dimensions(output->num_dimensions());
    output->set_valid_region(valid_region);
    Window win = calculate_max_window(valid_region, Steps());

    INEKernel::configure(win);
}

Status NEPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy,
                                                 RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));

    return Status{};
}

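// Usage sketch (illustrative only; src0, src1 and dst are hypothetical tensors):
//   NEPixelWiseMultiplicationKernel kernel;
//   kernel.configure(src0.info(), src1.info(), dst.info(), 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
//   ITensorPack pack;
//   pack.add_tensor(TensorType::ACL_SRC_0, &src0);
//   pack.add_tensor(TensorType::ACL_SRC_1, &src1);
//   pack.add_tensor(TensorType::ACL_DST, &dst);
//   kernel.run_op(pack, kernel.window(), ThreadInfo{}); // normally dispatched through the scheduler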
void NEPixelWiseMultiplicationKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    auto input1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto input2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto output = tensors.get_tensor(TensorType::ACL_DST);

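    // Dispatch to whichever function pointer configure() selected for these data types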
    if(_func_quantized != nullptr)
    {
        (*_func_quantized)(input1, input2, output, window, _scale);
    }
    else if(_func_int != nullptr)
    {
        (*_func_int)(input1, input2, output, window, _scale_exponent);
    }
    else
    {
        ARM_COMPUTE_ERROR_ON(_func_float == nullptr);
        (*_func_float)(input1, input2, output, window, _scale);
    }
}
namespace
{
Status validate_arguments_complex(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 2, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 2, DataType::F32);

    const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

    // Validate in case of configured output
    if(output->total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 2, DataType::F32);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
    }

    return Status{};
}
} // namespace

void NEComplexPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_complex(input1, input2, output));

    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
    const TensorShape &out_shape    = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    const TensorInfo out_info(out_shape, input1->num_channels(), input1->data_type());
    auto_init_if_empty(*output, out_info);

    // Configure kernel window
    Coordinates coord;
    coord.set_num_dimensions(output->num_dimensions());
    output->set_valid_region(valid_region);
    Window win = calculate_max_window(valid_region, Steps());

    INEKernel::configure(win);
}

Status NEComplexPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_complex(input1, input2, output));

    return Status{};
}

void NEComplexPixelWiseMultiplicationKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    auto input1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto input2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto output = tensors.get_tensor(TensorType::ACL_DST);

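    // Each element is a complex value stored as a 2-channel F32 pair (real, imaginary)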
    c_mul_F32_F32_F32_n(input1, input2, output, window);
}
} // namespace arm_compute