/*
 * Copyright (c) 2016-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"

#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/NEON/NEAsymm.h"
#include "arm_compute/core/NEON/NESymm.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"

#include <arm_neon.h>

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_fp16.h> // needed for float16_t
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

namespace arm_compute
{
namespace
{
const float       scale255_constant      = 1.f / 255.f;
const float32x4_t scale255_constant_f32q = vdupq_n_f32(scale255_constant);
const float32x4_t positive_round_f32q    = vdupq_n_f32(0.5f);

inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_UNUSED(overflow_policy);
    ARM_COMPUTE_UNUSED(rounding_policy);

    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::S32, DataType::QSYMM16, DataType::F16,
                                                         DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::S32, DataType::QSYMM16, DataType::F16,
                                                         DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
                                                         DataType::S16, DataType::QSYMM16,
                                                         DataType::S32, DataType::F16, DataType::F32);
    if(is_data_type_quantized(input1->data_type()) || is_data_type_quantized(input2->data_type()))
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(overflow_policy == ConvertPolicy::WRAP, "ConvertPolicy cannot be WRAP if datatype is quantized");
    }

    if(output->total_size() > 0)
    {
        const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
        // clang-format off
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(
            !(input1->data_type() == input2->data_type() && input2->data_type() == output->data_type()) &&
            !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) &&
            !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16) &&
            !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) &&
            !(input1->data_type() == DataType::QSYMM16 && input2->data_type() == DataType::QSYMM16 && output->data_type() == DataType::S32)
            , "Invalid data type combination");
        // clang-format on
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::QSYMM16 && output->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 output");
    }

    if(std::abs(scale - scale255_constant) < 0.00001f)
    {
        ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::S32 && input2->data_type() == DataType::S32 && output->data_type() == DataType::S32,
                                        "Scale == 1/255 is not supported if input and output are of data type S32");
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO);

        int         exponent            = 0;
        const float normalized_mantissa = std::frexp(scale, &exponent);

        // Use int scaling if factor is equal to 1/2^n for 0 <= n <= 15
        // frexp returns 0.5 as mantissa which means that the exponent will be in the range -14 <= e <= 1
        // Moreover, it will be negative for n >= 2 as we deal with 1/2^n
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(!((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1)), "Scale value not supported (should be 1/(2^n) or 1/255)");
    }

    return Status{};
}

/* Scales a given vector by 1/255.
 *
 * @note This does not work for all cases. e.g. for float of 0.49999999999999994 and large floats.
 *
 * @param in Input vector to scale.
 * @return Scaled output rounded to nearest (round half up).
 */
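// Worked example for the rounding below: in = 383 gives 383 * (1/255) = 1.502;
// adding 0.5 gives 2.002, and the truncating vcvt returns 2, i.e. round half up.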
inline int32x4_t scale255_S32_S32(int32x4_t in)
{
    // Scale
    const float32x4_t tmp = vmulq_f32(vcvtq_f32_s32(in), scale255_constant_f32q);
    // Round to nearest (round half up)
    // Add +0.5 for all values
    // Afterwards vcvt rounds toward zero
    return vcvtq_s32_f32(vaddq_f32(tmp, positive_round_f32q));
}

inline uint16x8_t scale255_U16_U16(uint16x8_t in)
{
    const int32x4_t tmp_s1 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(in))));
    const int32x4_t tmp_s2 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(in))));
    return vreinterpretq_u16_s16(vcombine_s16(vmovn_s32(tmp_s2), vmovn_s32(tmp_s1)));
}

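// Note on the two overloads below: T appears only inside std::enable_if and is never
// deduced from the arguments, so the unqualified vquantize()/vquantize_signed() calls
// resolve to the non-template helpers from NEAsymm.h; there is no self-recursion.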
template <typename T>
inline typename std::enable_if<std::is_same<T, int8_t>::value, int8x16_t>::type
vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
{
    return vquantize_signed(val, info);
}

template <typename T>
inline typename std::enable_if<std::is_same<T, uint8_t>::value, uint8x16_t>::type
vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
{
    return vquantize(val, info);
}

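// Requantization sketch for the quantized 8-bit path below: both inputs are dequantized
// to float, multiplied, then requantized with a temporary quantization info whose scale
// folds in the user-provided scale, i.e. tmp_scale = out_scale / scale, so that
// out_q ~= (in1_f * in2_f * scale) / out_scale + out_offset.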
template <typename T>
void mul_saturate_quantized_8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16 / sizeof(T);
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();
    const UniformQuantizationInfo tmp_qua_info    = { output_qua_info.scale / scale, output_qua_info.offset };

    if(is_broadcast_across_x)
    {
        const bool                    is_broadcast_input_2 = input2_win.x().step() == 0;
        Window                        broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window                        non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor                *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor                *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
        const UniformQuantizationInfo broadcast_qinfo      = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo  = non_broadcast_tensor->info()->quantization_info().uniform();

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        using ExactTagType = typename wrapper::traits::neon_vector<T, window_step_x>::tag_type;

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const T *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<T *>(output.ptr());

            const auto broadcast_value     = *reinterpret_cast<const T *>(broadcast_input.ptr());
            const auto broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);

                // Dequantize inputs
                const float32x4x4_t in1_f32x4x4 = vdequantize(non_broadcast_v, non_broadcast_qinfo);
                const float32x4x4_t in2_f32x4x4 = vdequantize(broadcast_value_vec, broadcast_qinfo);

                const float32x4x4_t out_f32x4x4 =
                {
                    vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
                    vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
                    vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
                    vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
                };

                // Quantize output
                const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
                wrapper::vstore(output_ptr + x, result);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                // Dequantize inputs
                const T     in1     = *(non_broadcast_input_ptr + x);
                const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(in1, non_broadcast_qinfo);
                const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(broadcast_value, broadcast_qinfo);
                const float tmp_f   = tmp_in1 * tmp_in2;

                // Quantize output
                const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info);
                *(output_ptr + x)  = tmp_qua;
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qua_info = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qua_info = in2->info()->quantization_info().uniform();

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const T *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const T *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<T *>(output.ptr());

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto input1_q = wrapper::vloadq(input1_ptr + x);
                const auto input2_q = wrapper::vloadq(input2_ptr + x);

                // Dequantize inputs
                const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
                const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);

                const float32x4x4_t out_f32x4x4 =
                {
                    vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
                    vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
                    vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
                    vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
                };

                // Quantize output
                const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
                wrapper::vstore(output_ptr + x, result);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                // Dequantize inputs
                const T     in1     = *(input1_ptr + x);
                const T     in2     = *(input2_ptr + x);
                const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(in1, input1_qua_info);
                const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(in2, input2_qua_info);
                const float tmp_f   = tmp_in1 * tmp_in2;

                // Quantize output
                const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info);
                *(output_ptr + x)  = tmp_qua;
            }
        },
        input1, input2, output);
    }
}

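// QSYMM16 is symmetrically quantized (zero offset), so only the scales matter when
// dequantizing below; the same dequantize -> multiply -> requantize scheme is used.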
void mul_saturate_QSYMM16_QSYMM16_QSYMM16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    const UniformQuantizationInfo input1_qua_info = in1->info()->quantization_info().uniform();
    const UniformQuantizationInfo input2_qua_info = in2->info()->quantization_info().uniform();
    const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();

    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<qsymm16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const qsymm16x8x2_t input1_q =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const qsymm16x8x2_t input2_q =
            {
                {
                    vld1q_s16(input2_ptr + x),
                    vld1q_s16(input2_ptr + x + 8),
                }
            };

            // Dequantize inputs
            const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
            const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);

            const float32x4x4_t out_f32x4x4 =
            {
                vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
                vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
                vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
                vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
            };

            const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info);
            vst1q_s16(output_ptr + x, result.val[0]);
            vst1q_s16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            // Dequantize inputs
            float tmp_in1 = static_cast<float>(*(input1_ptr + x)) * input1_qua_info.scale;
            float tmp_in2 = static_cast<float>(*(input2_ptr + x)) * input2_qua_info.scale;
            float tmp_f   = tmp_in1 * tmp_in2;

            // Quantize output; lrintf() rounds to nearest, matching the vectorized vquantize_qsymm16 path
            int32_t   tmp     = lrintf(tmp_f / tmp_qua_info.scale);
            qsymm16_t tmp_qua = static_cast<qsymm16_t>((tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp));
            *(output_ptr + x) = tmp_qua;
        }
    },
    input1, input2, output);
}

void mul_QSYMM16_QSYMM16_S32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int scale)
{
    ARM_COMPUTE_UNUSED(scale);

    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int32_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const qsymm16x8x2_t input1_q =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const qsymm16x8x2_t input2_q =
            {
                {
                    vld1q_s16(input2_ptr + x),
                    vld1q_s16(input2_ptr + x + 8),
                }
            };

            const int32x4x4_t in1_s32 =
            {
                {
                    vmovl_s16(vget_low_s16(input1_q.val[0])),
                    vmovl_s16(vget_high_s16(input1_q.val[0])),
                    vmovl_s16(vget_low_s16(input1_q.val[1])),
                    vmovl_s16(vget_high_s16(input1_q.val[1])),
                }
            };
            const int32x4x4_t in2_s32 =
            {
                {
                    vmovl_s16(vget_low_s16(input2_q.val[0])),
                    vmovl_s16(vget_high_s16(input2_q.val[0])),
                    vmovl_s16(vget_low_s16(input2_q.val[1])),
                    vmovl_s16(vget_high_s16(input2_q.val[1])),
                }
            };

            const int32x4x4_t result =
            {
                {
                    vmulq_s32(in1_s32.val[0], in2_s32.val[0]),
                    vmulq_s32(in1_s32.val[1], in2_s32.val[1]),
                    vmulq_s32(in1_s32.val[2], in2_s32.val[2]),
                    vmulq_s32(in1_s32.val[3], in2_s32.val[3]),
                }
            };

            vst1q_s32(output_ptr + x, result.val[0]);
            vst1q_s32(output_ptr + x + 4, result.val[1]);
            vst1q_s32(output_ptr + x + 8, result.val[2]);
            vst1q_s32(output_ptr + x + 12, result.val[3]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp       = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
            *(output_ptr + x) = tmp;
        }
    },
    input1, input2, output);
}

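// U8 * U8 -> U8: widen both operands to U16, multiply, apply the scale (either the
// 1/255 path or a right shift by n), then narrow back to U8 with or without saturation.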
template <bool is_scale255, bool is_sat>
void mul_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16 / sizeof(uint8_t);
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const uint8x16_t ta1 = wrapper::vloadq(input1_ptr + x);
            const uint8x16_t ta2 = wrapper::vloadq(input2_ptr + x);

            uint16x8_t       tmp1_high = vmovl_u8(vget_high_u8(ta1));
            const uint16x8_t tmp2_high = vmovl_u8(vget_high_u8(ta2));
            uint16x8_t       tmp1_low  = vmovl_u8(vget_low_u8(ta1));
            const uint16x8_t tmp2_low  = vmovl_u8(vget_low_u8(ta2));

            tmp1_high = vmulq_u16(tmp1_high, tmp2_high);
            tmp1_low  = vmulq_u16(tmp1_low, tmp2_low);

            if(is_scale255)
            {
                tmp1_high = scale255_U16_U16(tmp1_high);
                tmp1_low  = scale255_U16_U16(tmp1_low);
            }
            else
            {
                const int16x8_t vn = vdupq_n_s16(-n);

                if(is_sat)
                {
                    tmp1_high = vqshlq_u16(tmp1_high, vn);
                    tmp1_low  = vqshlq_u16(tmp1_low, vn);
                }
                else
                {
                    tmp1_high = vshlq_u16(tmp1_high, vn);
                    tmp1_low  = vshlq_u16(tmp1_low, vn);
                }
            }
            if(is_sat)
            {
                vst1q_u8(output_ptr + x, vcombine_u8(vqmovn_u16(tmp1_low), vqmovn_u16(tmp1_high)));
            }
            else
            {
                vst1q_u8(output_ptr + x, vcombine_u8(vmovn_u16(tmp1_low), vmovn_u16(tmp1_high)));
            }
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            uint16_t tmp = static_cast<uint16_t>(*(input1_ptr + x)) * static_cast<uint16_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;
                tmp         = static_cast<uint16_t>(tmp_f + 0.5f);
            }
            else
            {
                tmp >>= n;
            }
            if(is_sat && tmp > 255)
            {
                tmp = 255;
            }
            *(output_ptr + x) = static_cast<uint8_t>(tmp);
        }
    },
    input1, input2, output);
}

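// Signed right shifts below must round toward zero to honour RoundingPolicy::TO_ZERO,
// but an arithmetic shift rounds toward negative infinity. Adding (2^n - 1) to negative
// values first fixes this; the "convert" vectors compute sign * (2^n - 1) branch-free
// as (sign << n) - sign. E.g. n = 2, tmp = -5: a plain shift gives -2 (wrong for
// TO_ZERO), while (-5 + 3) >> 2 = -1 (correct).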
template <bool is_scale255, bool is_sat>
inline int16x8_t mul_S16_S16_S16_n_loop(const int16x8_t &input1, const int16x8_t &input2, int n)
{
    int32x4_t       tmp1_high = vmovl_s16(vget_high_s16(input1));
    const int32x4_t tmp2_high = vmovl_s16(vget_high_s16(input2));
    int32x4_t       tmp1_low  = vmovl_s16(vget_low_s16(input1));
    const int32x4_t tmp2_low  = vmovl_s16(vget_low_s16(input2));

    tmp1_high = vmulq_s32(tmp1_high, tmp2_high);
    tmp1_low  = vmulq_s32(tmp1_low, tmp2_low);

    if(is_scale255)
    {
        tmp1_high = scale255_S32_S32(tmp1_high);
        tmp1_low  = scale255_S32_S32(tmp1_low);
    }
    else
    {
        // Right shift amount
        const int32x4_t vn = vdupq_n_s32(-n);
        // Left shift amount
        const int32x4_t vnl = vdupq_n_s32(n);
        // Calculate conversion bit
        const uint32x4_t tmp1_high_u  = vreinterpretq_u32_s32(tmp1_high);
        const uint32x4_t tmp1_low_u   = vreinterpretq_u32_s32(tmp1_low);
        const uint32x4_t sign_high    = vshrq_n_u32(tmp1_high_u, 31);
        const uint32x4_t sign_low     = vshrq_n_u32(tmp1_low_u, 31);
        const int32x4_t  sign_high_s  = vreinterpretq_s32_u32(sign_high);
        const int32x4_t  sign_low_s   = vreinterpretq_s32_u32(sign_low);
        const int32x4_t  convert_high = vsubq_s32(vshlq_s32(sign_high_s, vnl), sign_high_s);
        const int32x4_t  convert_low  = vsubq_s32(vshlq_s32(sign_low_s, vnl), sign_low_s);
        if(is_sat)
        {
            tmp1_high = vqshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
            tmp1_low  = vqshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
        }
        else
        {
            tmp1_high = vshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
            tmp1_low  = vshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
        }
    }

    if(is_sat)
    {
        return vcombine_s16(vqmovn_s32(tmp1_low), vqmovn_s32(tmp1_high));
    }
    else
    {
        return vcombine_s16(vmovn_s32(tmp1_low), vmovn_s32(tmp1_high));
    }
}

template <bool is_scale255, bool is_sat>
inline int16x8x2_t mul_S16_S16_S16_n_k(const int16x8x2_t &input1, const int16x8x2_t &input2, int n)
{
    const int16x8x2_t result =
    {
        {
            // First 8 elements
            mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[0], input2.val[0], n),
            // Second 8 elements
            mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[1], input2.val[1], n)
        }
    };

    return result;
}

template <bool is_scale255, bool is_sat>
void mul_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const int16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const int16x8x2_t ta1 =
            {
                {
                    vld1q_s16(input1_ptr + x),
                    vld1q_s16(input1_ptr + x + 8),
                }
            };
            const int16x8x2_t ta2 =
            {
                {
                    vld1q_s16(input2_ptr + x),
                    vld1q_s16(input2_ptr + x + 8),
                }
            };
            const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);

            vst1q_s16(output_ptr + x, result.val[0]);
            vst1q_s16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;

                tmp = static_cast<int32_t>(tmp_f + 0.5f);
            }
            else
            {
                if(tmp >= 0)
                {
                    tmp >>= n;
                }
                else
                {
                    uint32_t mask = (1u << n) - 1;
                    tmp           = (tmp + static_cast<int32_t>(mask)) >> n;
                }
            }
            if(is_sat)
            {
                tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
            }
            *(output_ptr + x) = static_cast<int16_t>(tmp);
        }
    },
    input1, input2, output);
}

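// S32 * S32 -> S32 mirrors the S16 path but widens through 64-bit intermediates
// (vmull_s32), so the same round-toward-zero correction is applied on int64 lanes.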
template <bool is_sat>
inline int32x4_t mul_S32_S32_S32_n_loop(const int32x4_t &input1, const int32x4_t &input2, int n)
{
    const int32x2_t input1_1 = vget_low_s32(input1);
    const int32x2_t input2_1 = vget_low_s32(input2);
    const int32x2_t input1_2 = vget_high_s32(input1);
    const int32x2_t input2_2 = vget_high_s32(input2);

    int64x2_t tmp_1 = vmull_s32(input1_1, input2_1);
    int64x2_t tmp_2 = vmull_s32(input1_2, input2_2);

    // Apply scaling, conversion and rounding (round to zero)
    // Right shift amount
    const int64x2_t vn = vdupq_n_s64(-n);
    // Left shift amount
    const int64x2_t vnl = vdupq_n_s64(n);
    // Calculate conversion bit
    const uint64x2_t tmp_1_u   = vreinterpretq_u64_s64(tmp_1);
    const uint64x2_t sign_1    = vshrq_n_u64(tmp_1_u, 63);
    const int64x2_t  sign_1_s  = vreinterpretq_s64_u64(sign_1);
    const int64x2_t  convert_1 = vsubq_s64(vshlq_s64(sign_1_s, vnl), sign_1_s);

    const uint64x2_t tmp_2_u   = vreinterpretq_u64_s64(tmp_2);
    const uint64x2_t sign_2    = vshrq_n_u64(tmp_2_u, 63);
    const int64x2_t  sign_2_s  = vreinterpretq_s64_u64(sign_2);
    const int64x2_t  convert_2 = vsubq_s64(vshlq_s64(sign_2_s, vnl), sign_2_s);
    if(is_sat)
    {
        tmp_1 = vqshlq_s64(vaddq_s64(tmp_1, convert_1), vn);
        tmp_2 = vqshlq_s64(vaddq_s64(tmp_2, convert_2), vn);
        return vcombine_s32(vqmovn_s64(tmp_1), vqmovn_s64(tmp_2));
    }
    else
    {
        tmp_1 = vshlq_s64(vaddq_s64(tmp_1, convert_1), vn);
        tmp_2 = vshlq_s64(vaddq_s64(tmp_2, convert_2), vn);
        return vcombine_s32(vmovn_s64(tmp_1), vmovn_s64(tmp_2));
    }
}

template <bool is_sat>
inline int32x4x2_t mul_S32_S32_S32_n_k(const int32x4x2_t &input1, const int32x4x2_t &input2, int n)
{
    const int32x4x2_t result =
    {
        {
            // First 4 elements
            mul_S32_S32_S32_n_loop<is_sat>(input1.val[0], input2.val[0], n),
            // Second 4 elements
            mul_S32_S32_S32_n_loop<is_sat>(input1.val[1], input2.val[1], n)
        }
    };

    return result;
}

template <bool is_sat>
void mul_S32_S32_S32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 8;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const int32_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const int32_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int32_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const int32x4x2_t ta1 =
            {
                {
                    vld1q_s32(input1_ptr + x),
                    vld1q_s32(input1_ptr + x + 4),
                }
            };
            const int32x4x2_t ta2 =
            {
                {
                    vld1q_s32(input2_ptr + x),
                    vld1q_s32(input2_ptr + x + 4),
                }
            };
            const int32x4x2_t result = mul_S32_S32_S32_n_k<is_sat>(ta1, ta2, n);

            vst1q_s32(output_ptr + x, result.val[0]);
            vst1q_s32(output_ptr + x + 4, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int64_t tmp = static_cast<int64_t>(*(input1_ptr + x)) * static_cast<int64_t>(*(input2_ptr + x));

            if(tmp >= 0)
            {
                tmp >>= n;
            }
            else
            {
                uint64_t mask = (1ull << n) - 1;
                tmp           = (tmp + static_cast<int64_t>(mask)) >> n;
            }
            if(is_sat)
            {
                tmp = (tmp > INT_MAX) ? INT_MAX : ((tmp < INT_MIN) ? INT_MIN : tmp);
            }
            *(output_ptr + x) = static_cast<int32_t>(tmp);
        }
    },
    input1, input2, output);
}

void mul_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    constexpr int window_step_x         = 16 / sizeof(float);
    const auto    window_start_x        = static_cast<int>(window.x().start());
    const auto    window_end_x          = static_cast<int>(window.x().end());
    const bool    is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    using ExactTagType = typename wrapper::traits::neon_vector<float, window_step_x>::tag_type;

    if(is_broadcast_across_x)
    {
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const float *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<float *>(output.ptr());

            const float broadcast_value     = *reinterpret_cast<const float *>(broadcast_input.ptr());
            const auto  broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});
            const auto  scale_vec           = wrapper::vdup_n(scale, ExactTagType{});

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);
                auto       res             = wrapper::vmul(wrapper::vmul(broadcast_value_vec, non_broadcast_v), scale_vec);
                wrapper::vstore(output_ptr + x, res);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const auto non_broadcast_v = *(non_broadcast_input_ptr + x);
                *(output_ptr + x)          = broadcast_value * non_broadcast_v * scale;
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const float *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<float *>(output.ptr());

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto ta1       = wrapper::vloadq(input1_ptr + x);
                const auto ta2       = wrapper::vloadq(input2_ptr + x);
                const auto scale_vec = wrapper::vdup_n(scale, ExactTagType{});
                const auto res       = wrapper::vmul(wrapper::vmul(ta1, ta2), scale_vec);
                wrapper::vstore(output_ptr + x, res);
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                const auto ta1    = *(input1_ptr + x);
                const auto ta2    = *(input2_ptr + x);
                *(output_ptr + x) = ta1 * ta2 * scale;
            }
        },
        input1, input2, output);
    }
}

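// c_mul_F32_F32_F32_n multiplies two pairs of complex numbers stored interleaved as
// (re0, im0, re1, im1): res = (a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re).
// vrev64 swaps re/im within each pair and the (-1, 1) mask supplies the sign flip
// needed for the real component.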
void c_mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr)
{
    const auto input1 = static_cast<const float *__restrict>(input1_ptr);
    const auto input2 = static_cast<const float *__restrict>(input2_ptr);
    const auto output = static_cast<float *__restrict>(output_ptr);

    const float32x4_t a = wrapper::vloadq(input1);
    float32x4_t       b = wrapper::vloadq(input2);

    using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;

    const float32x4_t mask  = { -1.0f, 1.0f, -1.0f, 1.0f };
    const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
    const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
    const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{});
    const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{});

    const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10);
    const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11);

    float32x4_t res = wrapper::vmul(tmp0, b);

    b = wrapper::vrev64(b);
    b = wrapper::vmul(b, mask);

    res = wrapper::vmla(res, tmp1, b);
    wrapper::vstore(output, res);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void mul_F16_F16_F16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const float16_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const float16_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<float16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const float16x8x2_t ta1 =
            {
                {
                    vld1q_f16(input1_ptr + x),
                    vld1q_f16(input1_ptr + x + 8),
                }
            };
            const float16x8x2_t ta2 =
            {
                {
                    vld1q_f16(input2_ptr + x),
                    vld1q_f16(input2_ptr + x + 8),
                }
            };
            const float16x8_t   scale_vec = vdupq_n_f16(scale);
            const float16x8x2_t result    =
            {
                {
                    vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec),
                    vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec),
                }
            };
            vst1q_f16(output_ptr + x, result.val[0]);
            vst1q_f16(output_ptr + x + 8, result.val[1]);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            const auto ta1    = *(input1_ptr + x);
            const auto ta2    = *(input2_ptr + x);
            *(output_ptr + x) = ta1 * ta2 * scale;
        }
    },
    input1, input2, output);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

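// U8 * U8 -> S16: products of two U8 values fit in [0, 65025], so after the optional
// scale only an upper clamp to SHRT_MAX is needed when saturating; no lower bound check.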
template <bool is_scale255, bool is_sat>
void mul_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
{
    // Create input windows
    Window win        = window;
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
    input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator input1(in1, input1_win);
    Iterator input2(in2, input2_win);
    Iterator output(out, win);

    const int  window_step_x  = 16 / sizeof(uint8_t);
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
        const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
        const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());

        // Compute window_step_x elements per iteration
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const uint8x16_t bv = wrapper::vloadq(input2_ptr + x);
            const uint8x16_t av = wrapper::vloadq(input1_ptr + x);

            uint16x8_t tmp_low  = vmovl_u8(vget_low_u8(av));
            uint16x8_t tmp_high = vmovl_u8(vget_high_u8(av));
            tmp_low             = vmulq_u16(tmp_low, vmovl_u8(vget_low_u8(bv)));
            tmp_high            = vmulq_u16(tmp_high, vmovl_u8(vget_high_u8(bv)));

            if(is_scale255)
            {
                tmp_low  = scale255_U16_U16(tmp_low);
                tmp_high = scale255_U16_U16(tmp_high);
            }
            else
            {
                const int16x8_t vn = vdupq_n_s16(-n);

                if(is_sat)
                {
                    tmp_low  = vqshlq_u16(tmp_low, vn);
                    tmp_high = vqshlq_u16(tmp_high, vn);
                }
                else
                {
                    tmp_low  = vshlq_u16(tmp_low, vn);
                    tmp_high = vshlq_u16(tmp_high, vn);
                }
            }

            if(is_sat)
            {
                static const uint16x8_t max = vdupq_n_u16(SHRT_MAX);

                tmp_low  = vminq_u16(tmp_low, max);
                tmp_high = vminq_u16(tmp_high, max);
            }

            vst1q_s16(output_ptr + x, vreinterpretq_s16_u16(tmp_low));
            vst1q_s16(output_ptr + x + 8, vreinterpretq_s16_u16(tmp_high));
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));

            if(is_scale255)
            {
                float tmp_f = static_cast<float>(tmp) * scale255_constant;
                tmp         = static_cast<int32_t>(tmp_f + 0.5f);
            }
            else
            {
                tmp >>= n;
            }

            if(is_sat)
            {
                tmp = (tmp > SHRT_MAX) ? SHRT_MAX : tmp;
            }

            *(output_ptr + x) = static_cast<int16_t>(tmp);
        }
    },
    input1, input2, output);
}

Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001142void mul_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001143{
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001144 // Create input windows
1145 Window win = window;
1146 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
1147 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001148
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001149 // Clear X Dimension on execution window as we handle manually
1150 win.set(Window::DimX, Window::Dimension(0, 1, 1));
1151 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1152 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001153
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001154 Iterator input1(in1, input1_win);
1155 Iterator input2(in2, input2_win);
1156 Iterator output(out, win);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001157
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001158 const int window_step_x = 16;
1159 const auto window_start_x = static_cast<int>(window.x().start());
1160 const auto window_end_x = static_cast<int>(window.x().end());
1161
1162 execute_window_loop(win, [&](const Coordinates &)
1163 {
1164 const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
1165 const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
1166 const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
1167
1168 // Compute window_step_x elements per iteration
1169 int x = window_start_x;
1170 for(; x <= (window_end_x - window_step_x); x += window_step_x)
1171 {
1172 const int16x8x2_t ta1 =
1173 {
1174 {
1175 vld1q_s16(input1_ptr + x),
1176 vld1q_s16(input1_ptr + x + 8),
1177 }
1178 };
1179 const uint8x8x2_t ta2u =
1180 {
1181 {
1182 vld1_u8(input2_ptr + x),
1183 vld1_u8(input2_ptr + x + 8),
1184 }
1185 };
1186 const int16x8x2_t ta2 =
1187 {
1188 {
1189 vreinterpretq_s16_u16(vmovl_u8(ta2u.val[0])),
1190 vreinterpretq_s16_u16(vmovl_u8(ta2u.val[1]))
1191 }
1192 };
1193
1194 const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);
1195
1196 vst1q_s16(output_ptr + x, result.val[0]);
1197 vst1q_s16(output_ptr + x + 8, result.val[1]);
1198 }
1199
1200 // Compute left-over elements
1201 for(; x < window_end_x; ++x)
1202 {
1203 int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
1204
1205 if(is_scale255)
1206 {
1207 float tmp_f = static_cast<float>(tmp) * scale255_constant;
1208
1209 tmp = static_cast<int32_t>(tmp_f + 0.5f);
1210 }
1211 else
1212 {
1213 if(tmp >= 0)
1214 {
1215 tmp >>= n;
1216 }
1217 else
1218 {
1219 uint32_t mask = (1u << n) - 1;
1220 tmp = (tmp + static_cast<int32_t>(mask)) >> n;
1221 }
1222 }
1223 if(is_sat)
1224 {
1225 tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
1226 }
1227 *(output_ptr + x) = static_cast<int16_t>(tmp);
1228 }
1229 },
1230 input1, input2, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001231}
1232
1233template <bool is_scale255, bool is_sat>
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001234void mul_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001235{
1236 // Simply swap the two input buffers
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +01001237 mul_S16_U8_S16<is_scale255, is_sat>(in2, in1, out, window, n);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001238}
1239} // namespace
1240
NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel()
    : _func_float(nullptr), _func_int(nullptr), _func_quantized(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
{
}

void NEPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_UNUSED(rounding_policy);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));

    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
    const TensorShape &out_shape    = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    set_shape_if_empty(*output, out_shape);

    _scale          = scale;
    _scale_exponent = 0;
    _func_quantized = nullptr;
    _func_int       = nullptr;
    _func_float     = nullptr;

    bool is_scale_255 = false;
    // Check and validate scaling factor
    if(std::abs(scale - scale255_constant) < 0.00001f)
    {
        is_scale_255 = true;
    }
    else
    {
        int exponent = 0;

        std::frexp(scale, &exponent);

        // Store the positive exponent. We know that we compute 1/2^n
        // Additionally we need to subtract 1 to compensate that frexp used a mantissa of 0.5
        _scale_exponent = std::abs(exponent - 1);
    }

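    // Select the kernel implementation: exactly one of _func_quantized, _func_int or
    // _func_float is set, based on the input/output data types and, for the integer
    // paths, on whether the scale is 1/255 and whether overflow saturates.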
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001283 const DataType dt_input1 = input1->data_type();
1284 const DataType dt_input2 = input2->data_type();
1285 const DataType dt_output = output->data_type();
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001286 const bool is_sat = (overflow_policy == ConvertPolicy::SATURATE);
1287
    switch(dt_input1)
    {
        case DataType::QASYMM8:
            if(dt_input2 == DataType::QASYMM8 && dt_output == DataType::QASYMM8)
            {
                _func_quantized = &mul_saturate_quantized_8<uint8_t>;
            }
            break;
        case DataType::QASYMM8_SIGNED:
            if(dt_input2 == DataType::QASYMM8_SIGNED)
            {
                _func_quantized = &mul_saturate_quantized_8<int8_t>;
            }
            break;
        case DataType::QSYMM16:
            if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::QSYMM16)
            {
                _func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16;
            }
            else if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::S32)
            {
                _func_int = &mul_QSYMM16_QSYMM16_S32;
            }
            break;
        case DataType::S16:
            if(dt_input2 == DataType::U8 && dt_output == DataType::S16)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_S16_U8_S16<true, true> : &mul_S16_U8_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_S16_U8_S16<false, true> : &mul_S16_U8_S16<false, false>;
                }
            }
            if(dt_input2 == DataType::S16 && dt_output == DataType::S16)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_S16_S16_S16<true, true> : &mul_S16_S16_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_S16_S16_S16<false, true> : &mul_S16_S16_S16<false, false>;
                }
            }
            break;
        case DataType::S32:
            if(dt_input2 == DataType::S32 && dt_output == DataType::S32)
            {
                _func_int = is_sat ? &mul_S32_S32_S32<true> : &mul_S32_S32_S32<false>;
            }
            break;
        case DataType::U8:
            if(dt_input2 == DataType::U8 && dt_output == DataType::U8)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_U8_U8_U8<true, true> : &mul_U8_U8_U8<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_U8_U8_U8<false, true> : &mul_U8_U8_U8<false, false>;
                }
            }
            else if(dt_input2 == DataType::U8 && dt_output == DataType::S16)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_U8_U8_S16<true, true> : &mul_U8_U8_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_U8_U8_S16<false, true> : &mul_U8_U8_S16<false, false>;
                }
            }
            else if(dt_input2 == DataType::S16 && dt_output == DataType::S16)
            {
                if(is_scale_255)
                {
                    _func_int = is_sat ? &mul_U8_S16_S16<true, true> : &mul_U8_S16_S16<true, false>;
                }
                else
                {
                    _func_int = is_sat ? &mul_U8_S16_S16<false, true> : &mul_U8_S16_S16<false, false>;
                }
            }
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func_float = &mul_F16_F16_F16;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func_float = &mul_F32_F32_F32;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported combination of input and output data types");
    }

    // Configure the kernel window
    Coordinates coord;
    coord.set_num_dimensions(output->num_dimensions());
    output->set_valid_region(valid_region);
    Window win = calculate_max_window(valid_region, Steps());

    INEKernel::configure(win);
}

Status NEPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy,
                                                 RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));

    return Status{};
}

void NEPixelWiseMultiplicationKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    auto input1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto input2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto output = tensors.get_tensor(TensorType::ACL_DST);

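    // Exactly one of the three function pointers was set in configure();
    // dispatch to it: quantized first, then integer, then float.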
    if(_func_quantized != nullptr)
    {
        (*_func_quantized)(input1, input2, output, window, _scale);
    }
    else if(_func_int != nullptr)
    {
        (*_func_int)(input1, input2, output, window, _scale_exponent);
    }
    else
    {
        ARM_COMPUTE_ERROR_ON(_func_float == nullptr);
        (*_func_float)(input1, input2, output, window, _scale);
    }
}

namespace
{
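// Window x-step for the complex (2-channel F32) path: two elements per iteration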
constexpr unsigned int num_elems_processed_per_iteration_complex = 2;

Status validate_arguments_complex(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 2, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 2, DataType::F32);

    const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

    // Validate in case of a configured output
    if(output->total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 2, DataType::F32);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
    }

    return Status{};
}

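// Compute the broadcast output shape, auto-initialize the output tensor and set up
// the access windows; an error is returned if the window had to be altered because
// the tensors do not carry enough padding.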
std::pair<Status, Window> validate_and_configure_window_complex(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
{
    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
    const TensorShape &out_shape    = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize the output if it is not yet initialized
    const TensorInfo out_info(out_shape, input1->num_channels(), input1->data_type());
    auto_init_if_empty(*output, out_info);

    Window win        = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration_complex));
    Window win_input1 = win.broadcast_if_dimension_le_one(*input1);
    Window win_input2 = win.broadcast_if_dimension_le_one(*input2);

    AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration_complex);
    AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_complex);
    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_complex);

    bool window_changed = update_window_and_padding(win_input1, input1_access)
                          || update_window_and_padding(win_input2, input2_access)
                          || update_window_and_padding(win, output_access);

    output_access.set_valid_region(win, valid_region);

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
} // namespace

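// Minimal usage sketch (illustrative only; assumes in1, in2 and dst are valid,
// initialized 2-channel F32 ITensorInfo objects owned by the caller):
//
//   NEComplexPixelWiseMultiplicationKernel k;
//   k.configure(&in1, &in2, &dst);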
void NEComplexPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_complex(input1, input2, output));

    // Configure the kernel window
    auto win_config = validate_and_configure_window_complex(input1, input2, output);
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    // Create the kernel
    INEKernel::configure(win_config.second);
}

Status NEComplexPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_complex(input1, input2, output));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_complex(input1->clone().get(), input2->clone().get(), output->clone().get()).first);

    return Status{};
}

void NEComplexPixelWiseMultiplicationKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    auto input1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto input2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto output = tensors.get_tensor(TensorType::ACL_DST);

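    // Broadcast the window over any input dimension of size one so that a
    // single value can be multiplied against the full output window.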
    Iterator input1_it(input1, window.broadcast_if_dimension_le_one(input1->info()->tensor_shape()));
    Iterator input2_it(input2, window.broadcast_if_dimension_le_one(input2->info()->tensor_shape()));
    Iterator output_it(output, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        c_mul_F32_F32_F32_n(input1_it.ptr(), input2_it.ptr(), output_it.ptr());
    },
    input1_it, input2_it, output_it);
}
} // namespace arm_compute