/*
 * Copyright (c) 2016-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"

#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEAsymm.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/NESymm.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"

#include <arm_neon.h>
#include <climits>
#include <cmath>
#include <cstdint>
#include <cstdlib>

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_fp16.h> // needed for float16_t
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

namespace arm_compute
{
class Coordinates;

namespace
{
const float scale255_constant = 1.f / 255.f;
const float32x4_t scale255_constant_f32q = vdupq_n_f32(scale255_constant);
const float32x4_t positive_round_f32q = vdupq_n_f32(0.5f);

constexpr unsigned int num_elems_processed_per_iteration = 16;

inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_UNUSED(overflow_policy);
    ARM_COMPUTE_UNUSED(rounding_policy);

    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
                                    "Output can only be U8 if both inputs are U8");

    if(is_data_type_quantized(input1->data_type()) || is_data_type_quantized(input2->data_type()))
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(overflow_policy == ConvertPolicy::WRAP, "ConvertPolicy cannot be WRAP if datatype is quantized");
    }

    if(output->total_size() > 0)
    {
        if(is_data_type_quantized(output->data_type()))
        {
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2, output);
        }

        const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
    }

    if(std::abs(scale - scale255_constant) < 0.00001f)
    {
        ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO);

        int exponent = 0;
        const float normalized_mantissa = std::frexp(scale, &exponent);

        // Use int scaling if factor is equal to 1/2^n for 0 <= n <= 15
        // frexp returns a mantissa of 0.5, so for scale = 1/2^n the exponent is e = 1 - n,
        // i.e. in the range -14 <= e <= 1 (and non-positive for every n >= 1)
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(!((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1)), "Scale value not supported (Should be 1/(2^n) or 1/255)");
    }

    return Status{};
}
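
// Editor's note: an illustrative helper (an assumption added for clarity, not part of
// the original file) showing the frexp()-based test above in isolation. For
// scale = 1/2^n, frexp() yields a mantissa of 0.5f and an exponent e = 1 - n, so the
// accepted n in [0, 15] maps to e in [-14, 1].
inline bool is_supported_pot_scale(float scale)
{
    int exponent = 0;
    const float normalized_mantissa = std::frexp(scale, &exponent);
    // e.g. scale = 1/8  -> mantissa 0.5f, exponent -2 -> accepted
    //      scale = 0.3f -> mantissa 0.6f, exponent -1 -> rejected
    return (normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1);
}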

inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
{
    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    {
        ARM_COMPUTE_UNUSED(set_shape_if_empty(*output, input1->tensor_shape()));

        if(input1->data_type() == DataType::S16 || input2->data_type() == DataType::S16)
        {
            set_format_if_unknown(*output, Format::S16);
        }
        else if(input1->data_type() == DataType::F32 || input2->data_type() == DataType::F32)
        {
            set_format_if_unknown(*output, Format::F32);
        }
        else if(input1->data_type() == DataType::F16 || input2->data_type() == DataType::F16)
        {
            set_format_if_unknown(*output, Format::F16);
        }
        else if(input1->data_type() == DataType::QASYMM8)
        {
            set_data_type_if_unknown(*output, DataType::QASYMM8);
        }
        else if(input1->data_type() == DataType::QASYMM8_SIGNED)
        {
            set_data_type_if_unknown(*output, DataType::QASYMM8_SIGNED);
        }
        else if(input1->data_type() == DataType::QSYMM16)
        {
            set_data_type_if_unknown(*output, DataType::QSYMM16);
        }
    }

    // Configure kernel window
    Window win = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration));
    Window win_input1 = win.broadcast_if_dimension_le_one(*input1);
    Window win_input2 = win.broadcast_if_dimension_le_one(*input2);

    AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration);
    AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration);
    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);

    bool window_changed = update_window_and_padding(win_input1, input1_access)
                          || update_window_and_padding(win_input2, input2_access)
                          || update_window_and_padding(win, output_access);

    output_access.set_valid_region(win, valid_region);

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}

/** Scales a given vector by 1/255.
 *
 * @note This does not work for all cases, e.g. for a float input of 0.49999999999999994 or for large float values.
 *
 * @param in Input vector to scale.
 * @return Scaled output rounded to nearest (round half up).
 */
inline int32x4_t scale255_S32_S32(int32x4_t in)
{
    // Scale
    const float32x4_t tmp = vmulq_f32(vcvtq_f32_s32(in), scale255_constant_f32q);
    // Round to nearest (round half up)
    // Add +0.5 for all values
    // Afterwards vcvt rounds toward zero
    return vcvtq_s32_f32(vaddq_f32(tmp, positive_round_f32q));
}

inline uint16x8_t scale255_U16_U16(uint16x8_t in)
{
    const int32x4_t tmp_s1 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(in))));
    const int32x4_t tmp_s2 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(in))));
    return vreinterpretq_u16_s16(vcombine_s16(vmovn_s32(tmp_s2), vmovn_s32(tmp_s1)));
}
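
// Editor's note: a minimal scalar sketch (added for illustration, not part of the
// original file) of what the two helpers above compute per lane: multiply by 1/255,
// then round half up by adding 0.5 before the truncating float->int conversion.
// E.g. in = 32640 (255 * 128): 32640 * (1/255) = 128.0 -> 128.
inline int32_t scale255_scalar(int32_t in)
{
    return static_cast<int32_t>(static_cast<float>(in) * (1.f / 255.f) + 0.5f);
}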

inline void mul_saturate_QASYMM8_QASYMM8_QASYMM8_n_opt(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr,
                                                       float32x4_t input1_vscale, int32x4_t input1_voffset, float32x4_t input2_vscale, int32x4_t input2_voffset, float32x4_t output_voffset, float32x4_t vinvscale)
{
    const auto input1 = static_cast<const qasymm8_t *__restrict>(input1_ptr);
    const auto input2 = static_cast<const qasymm8_t *__restrict>(input2_ptr);
    const auto output = static_cast<qasymm8_t *__restrict>(output_ptr);

    const qasymm8x16_t input1_q = vld1q_u8(input1);
    const qasymm8x16_t input2_q = vld1q_u8(input2);

    // Dequantize inputs
    float32x4x4_t in1_f32x4x4;
    float32x4x4_t in2_f32x4x4;
    in1_f32x4x4.val[0] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(input1_q))))), input1_voffset)), input1_vscale);
    in1_f32x4x4.val[1] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(input1_q))))), input1_voffset)), input1_vscale);
    in1_f32x4x4.val[2] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(input1_q))))), input1_voffset)), input1_vscale);
    in1_f32x4x4.val[3] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(input1_q))))), input1_voffset)), input1_vscale);

    in2_f32x4x4.val[0] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(input2_q))))), input2_voffset)), input2_vscale);
    in2_f32x4x4.val[1] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(input2_q))))), input2_voffset)), input2_vscale);
    in2_f32x4x4.val[2] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(input2_q))))), input2_voffset)), input2_vscale);
    in2_f32x4x4.val[3] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(input2_q))))), input2_voffset)), input2_vscale);

    float32x4x4_t out_f32x4x4;
    out_f32x4x4.val[0] = vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]);
    out_f32x4x4.val[1] = vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]);
    out_f32x4x4.val[2] = vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]);
    out_f32x4x4.val[3] = vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]);

    int32x4x4_t rf;
#ifdef __aarch64__
    rf.val[0] = vcvtnq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[0], vinvscale));
    rf.val[1] = vcvtnq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[1], vinvscale));
    rf.val[2] = vcvtnq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[2], vinvscale));
    rf.val[3] = vcvtnq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[3], vinvscale));
#else  //__aarch64__
    rf.val[0] = vcvtq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[0], vinvscale));
    rf.val[1] = vcvtq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[1], vinvscale));
    rf.val[2] = vcvtq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[2], vinvscale));
    rf.val[3] = vcvtq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[3], vinvscale));
#endif //__aarch64__
    const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(rf.val[0]), vqmovn_s32(rf.val[1])));
    const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(rf.val[2]), vqmovn_s32(rf.val[3])));

    vst1q_u8(output, vcombine_u8(pa, pb));
}
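
// Editor's note: a scalar sketch (added for illustration; the name is hypothetical and
// this is not part of the original kernel) of the per-element arithmetic vectorised
// above: dequantize both operands, multiply, then requantize with the kernel scale
// folded in, i.e. q_out = round(f1 * f2 * (scale / output_scale) + output_offset),
// saturated to [0, 255]. The vector code rounds to nearest even on aarch64
// (vcvtnq_s32_f32) and truncates elsewhere; this sketch uses round half up.
inline uint8_t mul_saturate_qasymm8_scalar(uint8_t q1, uint8_t q2, const UniformQuantizationInfo &in1_qinfo,
                                           const UniformQuantizationInfo &in2_qinfo, const UniformQuantizationInfo &out_qinfo, float scale)
{
    const float f1 = (static_cast<int32_t>(q1) - in1_qinfo.offset) * in1_qinfo.scale;
    const float f2 = (static_cast<int32_t>(q2) - in2_qinfo.offset) * in2_qinfo.scale;
    int32_t q = static_cast<int32_t>(f1 * f2 * (scale / out_qinfo.scale) + out_qinfo.offset + 0.5f);
    q = (q < 0) ? 0 : ((q > 255) ? 255 : q); // saturate to the QASYMM8 range
    return static_cast<uint8_t>(q);
}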

inline void mul_saturate_QASYMM8_SIGNED_QASYMM8_SIGNED_QASYMM8_SIGNED_n(
    const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr,
    float scale, const UniformQuantizationInfo &input1_qua_info, const UniformQuantizationInfo &input2_qua_info,
    const UniformQuantizationInfo &output_qua_info)
{
    const auto input1 = static_cast<const qasymm8_signed_t *__restrict>(input1_ptr);
    const auto input2 = static_cast<const qasymm8_signed_t *__restrict>(input2_ptr);
    const auto output = static_cast<qasymm8_signed_t *__restrict>(output_ptr);
    const qasymm8x16_signed_t input1_q = vld1q_s8(input1);
    const qasymm8x16_signed_t input2_q = vld1q_s8(input2);
    // Dequantize inputs
    const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
    const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);
    const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };
    const float32x4x4_t out_f32x4x4 =
    {
        vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
        vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
        vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
        vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
    };
    const int8x16_t result = vquantize_signed(out_f32x4x4, tmp_qua_info);
    vst1q_s8(output, result);
}
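
// Editor's note (illustrative): the quantized paths above and below fold the kernel
// scale into the requantization step via tmp_qua_info: quantizing with
// scale' = output_scale / scale satisfies round(f / scale') == round(f * scale / output_scale),
// which saves a separate multiplication of the float product by `scale`.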

void mul_saturate_QSYMM16_QSYMM16_QSYMM16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale,
                                            const UniformQuantizationInfo &input1_qua_info, const UniformQuantizationInfo &input2_qua_info, const UniformQuantizationInfo &output_qua_info)
{
    const auto input1 = static_cast<const qsymm16_t *__restrict>(input1_ptr);
    const auto input2 = static_cast<const qsymm16_t *__restrict>(input2_ptr);
    const auto output = static_cast<qsymm16_t *__restrict>(output_ptr);

    const qsymm16x8x2_t input1_q = vld2q_s16(input1);
    const qsymm16x8x2_t input2_q = vld2q_s16(input2);

    // Dequantize inputs
    const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
    const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);

    const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };

    const float32x4x4_t out_f32x4x4 =
    {
        vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
        vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
        vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
        vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
    };

    const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info);
    vst2q_s16(output, result);
}

template <bool is_scale255, bool is_sat>
void mul_U8_U8_U8_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
{
    const auto input1 = static_cast<const uint8_t *__restrict>(input1_ptr);
    const auto input2 = static_cast<const uint8_t *__restrict>(input2_ptr);
    const auto output = static_cast<uint8_t *__restrict>(output_ptr);

    const uint8x16_t ta1 = vld1q_u8(input1);
    const uint8x16_t ta2 = vld1q_u8(input2);

    uint16x8_t tmp1_high = vmovl_u8(vget_high_u8(ta1));
    const uint16x8_t tmp2_high = vmovl_u8(vget_high_u8(ta2));
    uint16x8_t tmp1_low = vmovl_u8(vget_low_u8(ta1));
    const uint16x8_t tmp2_low = vmovl_u8(vget_low_u8(ta2));

    tmp1_high = vmulq_u16(tmp1_high, tmp2_high);
    tmp1_low = vmulq_u16(tmp1_low, tmp2_low);

    if(is_scale255)
    {
        tmp1_high = scale255_U16_U16(tmp1_high);
        tmp1_low = scale255_U16_U16(tmp1_low);
    }
    else
    {
        const int16x8_t vn = vdupq_n_s16(-n);

        if(is_sat)
        {
            tmp1_high = vqshlq_u16(tmp1_high, vn);
            tmp1_low = vqshlq_u16(tmp1_low, vn);
        }
        else
        {
            tmp1_high = vshlq_u16(tmp1_high, vn);
            tmp1_low = vshlq_u16(tmp1_low, vn);
        }
    }

    if(is_sat)
    {
        vst1q_u8(output, vcombine_u8(vqmovn_u16(tmp1_low), vqmovn_u16(tmp1_high)));
    }
    else
    {
        vst1q_u8(output, vcombine_u8(vmovn_u16(tmp1_low), vmovn_u16(tmp1_high)));
    }
}
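
// Editor's note: a scalar model (added for illustration, not original code) of the
// non-scale255 path above for scale = 1/2^n. The 16-bit product of two uint8_t values
// is at most 255 * 255 = 65025, so the right shift itself never saturates; saturation
// only matters when narrowing back to 8 bits.
inline uint8_t mul_u8_u8_u8_scalar(uint8_t a, uint8_t b, int n, bool is_sat)
{
    const uint16_t shifted = static_cast<uint16_t>(static_cast<uint16_t>(a) * static_cast<uint16_t>(b)) >> n;
    if(is_sat)
    {
        return (shifted > 255) ? 255 : static_cast<uint8_t>(shifted); // vqmovn_u16
    }
    return static_cast<uint8_t>(shifted); // vmovn_u16 truncates (wraps modulo 256)
}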

template <bool is_scale255, bool is_sat>
inline int16x8_t mul_S16_S16_S16_n_loop(const int16x8_t &input1, const int16x8_t &input2, int n)
{
    int32x4_t tmp1_high = vmovl_s16(vget_high_s16(input1));
    const int32x4_t tmp2_high = vmovl_s16(vget_high_s16(input2));
    int32x4_t tmp1_low = vmovl_s16(vget_low_s16(input1));
    const int32x4_t tmp2_low = vmovl_s16(vget_low_s16(input2));

    tmp1_high = vmulq_s32(tmp1_high, tmp2_high);
    tmp1_low = vmulq_s32(tmp1_low, tmp2_low);

    if(is_scale255)
    {
        tmp1_high = scale255_S32_S32(tmp1_high);
        tmp1_low = scale255_S32_S32(tmp1_low);
    }
    else
    {
        // Right shift amount
        const int32x4_t vn = vdupq_n_s32(-n);
        // Left shift amount
        const int32x4_t vnl = vdupq_n_s32(n);
        // Calculate conversion bit
        const uint32x4_t tmp1_high_u = vreinterpretq_u32_s32(tmp1_high);
        const uint32x4_t tmp1_low_u = vreinterpretq_u32_s32(tmp1_low);
        const uint32x4_t sign_high = vshrq_n_u32(tmp1_high_u, 31);
        const uint32x4_t sign_low = vshrq_n_u32(tmp1_low_u, 31);
        const int32x4_t sign_high_s = vreinterpretq_s32_u32(sign_high);
        const int32x4_t sign_low_s = vreinterpretq_s32_u32(sign_low);
        const int32x4_t convert_high = vsubq_s32(vshlq_s32(sign_high_s, vnl), sign_high_s);
        const int32x4_t convert_low = vsubq_s32(vshlq_s32(sign_low_s, vnl), sign_low_s);
        if(is_sat)
        {
            tmp1_high = vqshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
            tmp1_low = vqshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
        }
        else
        {
            tmp1_high = vshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
            tmp1_low = vshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
        }
    }

    if(is_sat)
    {
        return vcombine_s16(vqmovn_s32(tmp1_low), vqmovn_s32(tmp1_high));
    }
    else
    {
        return vcombine_s16(vmovn_s32(tmp1_low), vmovn_s32(tmp1_high));
    }
}
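
// Editor's note: a scalar sketch (illustrative, not original code) of the "conversion
// bit" trick above. A plain arithmetic right shift rounds toward negative infinity;
// adding (2^n - 1) to negative values first makes the shift round toward zero, which is
// what RoundingPolicy::TO_ZERO requires. E.g. x = -5, n = 1: (-5 + 1) >> 1 = -2, whereas
// -5 >> 1 = -3.
inline int32_t shift_right_to_zero_scalar(int32_t x, int n)
{
    const int32_t sign = (x < 0) ? 1 : 0;       // vshrq_n_u32(x, 31) in the vector code
    const int32_t convert = (sign << n) - sign; // sign ? (1 << n) - 1 : 0
    return (x + convert) >> n;
}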

template <bool is_scale255, bool is_sat>
inline int16x8x2_t mul_S16_S16_S16_n_k(const int16x8x2_t &input1, const int16x8x2_t &input2, int n)
{
    const int16x8x2_t result =
    {
        {
            // First 8 elements
            mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[0], input2.val[0], n),
            // Second 8 elements
            mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[1], input2.val[1], n)
        }
    };

    return result;
}

template <bool is_scale255, bool is_sat>
void mul_S16_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
{
    const auto input1 = static_cast<const int16_t *__restrict>(input1_ptr);
    const auto input2 = static_cast<const int16_t *__restrict>(input2_ptr);
    const auto output = static_cast<int16_t *__restrict>(output_ptr);

    const int16x8x2_t ta1 = vld2q_s16(input1);
    const int16x8x2_t ta2 = vld2q_s16(input2);
    const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);

    vst2q_s16(output, result);
}

void mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale)
{
    const auto input1 = static_cast<const float *__restrict>(input1_ptr);
    const auto input2 = static_cast<const float *__restrict>(input2_ptr);
    const auto output = static_cast<float *__restrict>(output_ptr);

    const float32x4x4_t ta1 = vld4q_f32(input1);
    const float32x4x4_t ta2 = vld4q_f32(input2);
    const float32x4_t scale_vec = vdupq_n_f32(scale);
    const float32x4x4_t result =
    {
        {
            vmulq_f32(vmulq_f32(ta1.val[0], ta2.val[0]), scale_vec),
            vmulq_f32(vmulq_f32(ta1.val[1], ta2.val[1]), scale_vec),
            vmulq_f32(vmulq_f32(ta1.val[2], ta2.val[2]), scale_vec),
            vmulq_f32(vmulq_f32(ta1.val[3], ta2.val[3]), scale_vec)
        }
    };
    vst4q_f32(output, result);
}

void c_mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr)
{
    const auto input1 = static_cast<const float *__restrict>(input1_ptr);
    const auto input2 = static_cast<const float *__restrict>(input2_ptr);
    const auto output = static_cast<float *__restrict>(output_ptr);

    const float32x4_t a = wrapper::vloadq(input1);
    float32x4_t b = wrapper::vloadq(input2);

    using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;

    const float32x4_t mask = { -1.0f, 1.0f, -1.0f, 1.0f };
    const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
    const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
    const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{});
    const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{});

    const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10);
    const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11);

    float32x4_t res = wrapper::vmul(tmp0, b);

    b = wrapper::vrev64(b);
    b = wrapper::vmul(b, mask);

    res = wrapper::vmla(res, tmp1, b);
    wrapper::vstore(output, res);
}
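
// Editor's note: the scalar complex product the routine above vectorises (added for
// illustration, not part of the original file). Each pair of floats is one complex
// number (re, im), two complex numbers per 128-bit vector:
//   (a_re + i*a_im) * (b_re + i*b_im) = (a_re*b_re - a_im*b_im) + i*(a_re*b_im + a_im*b_re)
// The vrev64 + mask multiply builds the cross terms with the required sign flip.
inline void c_mul_scalar(const float *a, const float *b, float *out)
{
    out[0] = a[0] * b[0] - a[1] * b[1]; // real part
    out[1] = a[0] * b[1] + a[1] * b[0]; // imaginary part
}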

void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale)
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    const auto input1 = static_cast<const float16_t *__restrict>(input1_ptr);
    const auto input2 = static_cast<const float16_t *__restrict>(input2_ptr);
    const auto output = static_cast<float16_t *__restrict>(output_ptr);
    const float16x8x2_t ta1 = vld2q_f16(input1);
    const float16x8x2_t ta2 = vld2q_f16(input2);
    const float16x8_t scale_vec = vdupq_n_f16(scale);
    const float16x8x2_t result =
    {
        {
            vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec),
            vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec),
        }
    };
    vst2q_f16(output, result);
#else  /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
    ARM_COMPUTE_UNUSED(input1_ptr);
    ARM_COMPUTE_UNUSED(input2_ptr);
    ARM_COMPUTE_UNUSED(output_ptr);
    ARM_COMPUTE_UNUSED(scale);
    ARM_COMPUTE_ERROR("Not supported. Recompile the library with arch=arm64-v8.2-a.");
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
}

template <bool is_scale255, bool is_sat>
void mul_U8_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
{
    const auto input1 = static_cast<const uint8_t *__restrict>(input1_ptr);
    const auto input2 = static_cast<const uint8_t *__restrict>(input2_ptr);
    const auto output = static_cast<int16_t *__restrict>(output_ptr);

    const uint8x16_t bv = vld1q_u8(input2);
    const uint8x16_t av = vld1q_u8(input1);

    uint16x8_t tmp_low = vmovl_u8(vget_low_u8(av));
    uint16x8_t tmp_high = vmovl_u8(vget_high_u8(av));
    tmp_low = vmulq_u16(tmp_low, vmovl_u8(vget_low_u8(bv)));
    tmp_high = vmulq_u16(tmp_high, vmovl_u8(vget_high_u8(bv)));

    if(is_scale255)
    {
        tmp_low = scale255_U16_U16(tmp_low);
        tmp_high = scale255_U16_U16(tmp_high);
    }
    else
    {
        const int16x8_t vn = vdupq_n_s16(-n);

        if(is_sat)
        {
            tmp_low = vqshlq_u16(tmp_low, vn);
            tmp_high = vqshlq_u16(tmp_high, vn);
        }
        else
        {
            tmp_low = vshlq_u16(tmp_low, vn);
            tmp_high = vshlq_u16(tmp_high, vn);
        }
    }

    if(is_sat)
    {
        static const uint16x8_t max = vdupq_n_u16(SHRT_MAX);

        tmp_low = vminq_u16(tmp_low, max);
        tmp_high = vminq_u16(tmp_high, max);
    }

    vst1q_s16(output, vreinterpretq_s16_u16(tmp_low));
    vst1q_s16(output + 8, vreinterpretq_s16_u16(tmp_high));
}

template <bool is_scale255, bool is_sat>
void mul_S16_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
{
    const auto input1 = static_cast<const int16_t *__restrict>(input1_ptr);
    const auto input2 = static_cast<const uint8_t *__restrict>(input2_ptr);
    const auto output = static_cast<int16_t *__restrict>(output_ptr);

    const int16x8x2_t ta1 = vld2q_s16(input1);
    const uint8x8x2_t ta2u = vld2_u8(input2);
    const int16x8x2_t ta2 =
    {
        {
            vreinterpretq_s16_u16(vmovl_u8(ta2u.val[0])),
            vreinterpretq_s16_u16(vmovl_u8(ta2u.val[1]))
        }
    };

    const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);

    vst2q_s16(output, result);
}

template <bool is_scale255, bool is_sat>
void mul_U8_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
{
    // Simply swap the two input buffers
    mul_S16_U8_S16_n<is_scale255, is_sat>(input2_ptr, input1_ptr, output_ptr, n);
}
} // namespace

NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel()
    : _func_float(nullptr), _func_int(nullptr), _func_quantized(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 }, _run_optimized_qasymm8(false)
{
}

void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_UNUSED(rounding_policy);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info(), scale, overflow_policy, rounding_policy));

    // Configure kernel window
    auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    _input1 = input1;
    _input2 = input2;
    _output = output;
    _scale = scale;
    _scale_exponent = 0;
    _func_quantized = nullptr;
    _func_int = nullptr;
    _func_float = nullptr;
    _run_optimized_qasymm8 = false;

    bool is_scale_255 = false;
    // Check and validate scaling factor
    if(std::abs(scale - scale255_constant) < 0.00001f)
    {
        is_scale_255 = true;
    }
    else
    {
        int exponent = 0;

        std::frexp(scale, &exponent);

        // Store the positive exponent. We know that we compute 1/2^n
        // Additionally we need to subtract 1 to compensate for the mantissa of 0.5 that frexp returns
        _scale_exponent = std::abs(exponent - 1);
    }

    const DataType dt_input1 = input1->info()->data_type();
    const DataType dt_input2 = input2->info()->data_type();
    const DataType dt_output = output->info()->data_type();
    const bool is_sat = (overflow_policy == ConvertPolicy::SATURATE);

    if(dt_input1 == DataType::QASYMM8 && dt_input2 == DataType::QASYMM8)
    {
        _run_optimized_qasymm8 = true;
    }
    else if(dt_input1 == DataType::QASYMM8_SIGNED && dt_input2 == DataType::QASYMM8_SIGNED)
    {
        _func_quantized = &mul_saturate_QASYMM8_SIGNED_QASYMM8_SIGNED_QASYMM8_SIGNED_n;
    }
    else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16)
    {
        _func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16_n;
    }
    else if(DataType::U8 == dt_input1 && DataType::U8 == dt_input2 && DataType::U8 == dt_output)
    {
        if(is_scale_255)
        {
            _func_int = is_sat ? &mul_U8_U8_U8_n<true, true> : &mul_U8_U8_U8_n<true, false>;
        }
        else
        {
            _func_int = is_sat ? &mul_U8_U8_U8_n<false, true> : &mul_U8_U8_U8_n<false, false>;
        }
    }
    else if(DataType::S16 == dt_input1 && DataType::S16 == dt_input2 && DataType::S16 == dt_output)
    {
        if(is_scale_255)
        {
            _func_int = is_sat ? &mul_S16_S16_S16_n<true, true> : &mul_S16_S16_S16_n<true, false>;
        }
        else
        {
            _func_int = is_sat ? &mul_S16_S16_S16_n<false, true> : &mul_S16_S16_S16_n<false, false>;
        }
    }
    else if(DataType::S16 == dt_input1 && DataType::U8 == dt_input2 && DataType::S16 == dt_output)
    {
        if(is_scale_255)
        {
            _func_int = is_sat ? &mul_S16_U8_S16_n<true, true> : &mul_S16_U8_S16_n<true, false>;
        }
        else
        {
            _func_int = is_sat ? &mul_S16_U8_S16_n<false, true> : &mul_S16_U8_S16_n<false, false>;
        }
    }
    else if(DataType::U8 == dt_input1 && DataType::S16 == dt_input2 && DataType::S16 == dt_output)
    {
        if(is_scale_255)
        {
            _func_int = is_sat ? &mul_U8_S16_S16_n<true, true> : &mul_U8_S16_S16_n<true, false>;
        }
        else
        {
            _func_int = is_sat ? &mul_U8_S16_S16_n<false, true> : &mul_U8_S16_S16_n<false, false>;
        }
    }
    else if(DataType::U8 == dt_input1 && DataType::U8 == dt_input2 && DataType::S16 == dt_output)
    {
        if(is_scale_255)
        {
            _func_int = is_sat ? &mul_U8_U8_S16_n<true, true> : &mul_U8_U8_S16_n<true, false>;
        }
        else
        {
            _func_int = is_sat ? &mul_U8_U8_S16_n<false, true> : &mul_U8_U8_S16_n<false, false>;
        }
    }
    else if(DataType::F16 == dt_input1 && DataType::F16 == dt_input2 && DataType::F16 == dt_output)
    {
        _func_float = &mul_F16_F16_F16_n;
        _func_int = nullptr;
    }
    else if(DataType::F32 == dt_input1 && DataType::F32 == dt_input2 && DataType::F32 == dt_output)
    {
        _func_float = &mul_F32_F32_F32_n;
        _func_int = nullptr;
    }
    else
    {
        ARM_COMPUTE_ERROR("Unsupported combination of input and output data types");
    }

    INEKernel::configure(win_config.second);
}

Status NEPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy,
                                                 RoundingPolicy rounding_policy)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first);

    return Status{};
}
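
// Editor's usage sketch (an assumption added for illustration; applications normally go
// through the NEPixelWiseMultiplication runtime function rather than using the kernel
// directly):
//
//   Tensor input1, input2, output; // e.g. F32 tensors with broadcast-compatible shapes
//   NEPixelWiseMultiplicationKernel kernel;
//   kernel.configure(&input1, &input2, &output, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
//   // ... allocate the tensors, then run the kernel through the scheduler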

void NEPixelWiseMultiplicationKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    const TensorShape &in_shape1 = _input1->info()->tensor_shape();
    const TensorShape &in_shape2 = _input2->info()->tensor_shape();
    const TensorShape &out_shape = _output->info()->tensor_shape();

    bool can_collapse = true;
    if(std::min(in_shape1.total_size(), in_shape2.total_size()) > 1)
    {
        can_collapse = (std::min(in_shape1.num_dimensions(), in_shape2.num_dimensions()) > Window::DimZ);
        for(size_t d = Window::DimZ; can_collapse && (d < out_shape.num_dimensions()); ++d)
        {
            can_collapse = (in_shape1[d] == in_shape2[d]);
        }
    }

    bool has_collapsed = false;
    Window collapsed = can_collapse ? window.collapse_if_possible(INEKernel::window(), Window::DimZ, &has_collapsed) : window;

    const TensorShape &in_shape1_collapsed = has_collapsed ? in_shape1.collapsed_from(Window::DimZ) : in_shape1;
    const TensorShape &in_shape2_collapsed = has_collapsed ? in_shape2.collapsed_from(Window::DimZ) : in_shape2;

    Window slice = collapsed.first_slice_window_3D();
    Window slice_input1 = slice.broadcast_if_dimension_le_one(in_shape1_collapsed);
    Window slice_input2 = slice.broadcast_if_dimension_le_one(in_shape2_collapsed);

    Iterator input1(_input1, slice_input1);
    Iterator input2(_input2, slice_input2);
    Iterator output(_output, slice);

    if(is_data_type_quantized(_input1->info()->data_type()))
    {
        if(_run_optimized_qasymm8)
        {
            const int32x4_t input1_voffset = vdupq_n_s32(_input1->info()->quantization_info().uniform().offset);
            const float32x4_t input1_vscale = vdupq_n_f32(_input1->info()->quantization_info().uniform().scale);
            const int32x4_t input2_voffset = vdupq_n_s32(_input2->info()->quantization_info().uniform().offset);
            const float32x4_t input2_vscale = vdupq_n_f32(_input2->info()->quantization_info().uniform().scale);
            const float32x4_t output_voffset = vdupq_n_f32(static_cast<float>(_output->info()->quantization_info().uniform().offset));
            const float output_scale = _output->info()->quantization_info().uniform().scale;
            const float32x4_t vinvscale = vdupq_n_f32(1.f / (output_scale / _scale));

            execute_window_loop(collapsed, [&](const Coordinates &)
            {
                mul_saturate_QASYMM8_QASYMM8_QASYMM8_n_opt(input1.ptr(), input2.ptr(), output.ptr(),
                                                           input1_vscale, input1_voffset, input2_vscale, input2_voffset, output_voffset, vinvscale);
                ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input1));
                ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input2));
            },
            input1, input2, output);
        }
        else
        {
            execute_window_loop(collapsed, [&](const Coordinates &)
            {
                (*_func_quantized)(input1.ptr(), input2.ptr(), output.ptr(), _scale,
                                   _input1->info()->quantization_info().uniform(), _input2->info()->quantization_info().uniform(), _output->info()->quantization_info().uniform());
                ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input1));
                ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input2));
            },
            input1, input2, output);
        }
    }
    else if(_func_int != nullptr)
    {
        execute_window_loop(collapsed, [&](const Coordinates &)
        {
            (*_func_int)(input1.ptr(), input2.ptr(), output.ptr(), _scale_exponent);
            ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input1));
            ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input2));
        },
        input1, input2, output);
    }
    else
    {
        ARM_COMPUTE_ERROR_ON(_func_float == nullptr);
        execute_window_loop(collapsed, [&](const Coordinates &)
        {
            (*_func_float)(input1.ptr(), input2.ptr(), output.ptr(), _scale);
            ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input1));
            ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input2));
        },
        input1, input2, output);
    }
}

BorderSize NEPixelWiseMultiplicationKernel::border_size() const
{
    const unsigned int replicateSize = _output->info()->dimension(0) - std::min(_input1->info()->dimension(0), _input2->info()->dimension(0));
    const unsigned int border = std::min<unsigned int>(num_elems_processed_per_iteration - 1U, replicateSize);
    return BorderSize{ 0, border, 0, 0 };
}
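
// Editor's note (illustrative): with num_elems_processed_per_iteration = 16, broadcasting
// an input of width 1 against an output of width 64 gives replicateSize = 64 - 1 = 63,
// so border = min(16 - 1, 63) = 15 elements of padding on the right of the window.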

namespace
{
constexpr unsigned int num_elems_processed_per_iteration_complex = 2;

Status validate_arguments_complex(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 2, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 2, DataType::F32);

    const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

    // Validate in case of configured output
    if(output->total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 2, DataType::F32);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window_complex(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
{
    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
    const TensorShape &out_shape = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    const TensorInfo out_info(out_shape, input1->num_channels(), input1->data_type());
    auto_init_if_empty(*output, out_info);

    Window win = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration_complex));
    Window win_input1 = win.broadcast_if_dimension_le_one(*input1);
    Window win_input2 = win.broadcast_if_dimension_le_one(*input2);

    AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration_complex);
    AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_complex);
    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_complex);

    bool window_changed = update_window_and_padding(win_input1, input1_access)
                          || update_window_and_padding(win_input2, input2_access)
                          || update_window_and_padding(win, output_access);

    output_access.set_valid_region(win, valid_region);

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
} // namespace

NEComplexPixelWiseMultiplicationKernel::NEComplexPixelWiseMultiplicationKernel()
    : _input1(nullptr), _input2(nullptr), _output(nullptr)
{
}

void NEComplexPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_complex(input1->info(), input2->info(), output->info()));

    // Configure kernel window
    auto win_config = validate_and_configure_window_complex(input1->info(), input2->info(), output->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    _input1 = input1;
    _input2 = input2;
    _output = output;

    // Create kernel
    INEKernel::configure(win_config.second);
}

Status NEComplexPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_complex(input1, input2, output));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_complex(input1->clone().get(), input2->clone().get(), output->clone().get()).first);

    return Status{};
}

void NEComplexPixelWiseMultiplicationKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    Iterator input1(_input1, window.broadcast_if_dimension_le_one(_input1->info()->tensor_shape()));
    Iterator input2(_input2, window.broadcast_if_dimension_le_one(_input2->info()->tensor_shape()));
    Iterator output(_output, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        c_mul_F32_F32_F32_n(input1.ptr(), input2.ptr(), output.ptr());
    },
    input1, input2, output);
}

BorderSize NEComplexPixelWiseMultiplicationKernel::border_size() const
{
    const unsigned int replicateSize = _output->info()->dimension(0) - std::min(_input1->info()->dimension(0), _input2->info()->dimension(0));
    const unsigned int border = std::min<unsigned int>(num_elems_processed_per_iteration_complex - 1U, replicateSize);
    return { 0, border, 0, 0 };
}
} // namespace arm_compute