/*
 * Copyright (c) 2018-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEElementwiseOperationKernel.h"

#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/NEON/NEAsymm.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"

#include <arm_neon.h>
#include <map>

namespace arm_compute
{
namespace
{
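// Loads 16 quantized (QASYMM8) values and dequantizes them into four
// float32x4_t lanes as (value - offset) * scale.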
float32x4x4_t load_quantized(const uint8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
{
    qasymm8x16_t x = vld1q_u8(input1_ptr);
    const float32x4x4_t out =
    {
        {
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
        }
    };
    return out;
}

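// Signed counterpart of load_quantized for QASYMM8_SIGNED data; the widened
// lanes are already signed, so no reinterpret is required.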
float32x4x4_t load_quantized_signed(const int8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
{
    qasymm8x16_signed_t x = vld1q_s8(input1_ptr);
    const float32x4x4_t out =
    {
        {
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
        }
    };
    return out;
}

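// The store_quantized overloads narrow four 32-bit vectors (comparison masks
// or requantized values) down to 16 bytes with saturation before storing.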
void store_quantized(uint8_t *output_ptr, const uint32x4x4_t &out)
{
    const uint8x8_t pa = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[0]), vqmovn_u32(out.val[1])));
    const uint8x8_t pb = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[2]), vqmovn_u32(out.val[3])));
    vst1q_u8(output_ptr, vcombine_u8(pa, pb));
}

void store_quantized(uint8_t *output_ptr, const int32x4x4_t &out)
{
    const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
    const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
    vst1q_u8(output_ptr, vcombine_u8(pa, pb));
}

void store_quantized(uint8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
{
    int32x4x4_t out =
    {
        {
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)),
        }
    };
    store_quantized(output_ptr, out);
}

void store_quantized_signed(int8_t *output_ptr, const int32x4x4_t &out)
{
    const int8x8_t pa = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
    const int8x8_t pb = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
    vst1q_s8(output_ptr, vcombine_s8(pa, pb));
}

void store_quantized_signed(int8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
{
    int32x4x4_t out =
    {
        {
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)),
        }
    };
    store_quantized_signed(output_ptr, out);
}

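// Scalar reference implementation of the arithmetic operations; used for the
// leftover elements that do not fill a complete vector.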
template <ArithmeticOperation op, typename ScalarType>
inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const ScalarType &b)
{
    auto res = ScalarType(0);

    switch(op)
    {
        case ArithmeticOperation::MAX:
            res = std::max(a, b);
            break;
        case ArithmeticOperation::MIN:
            res = std::min(a, b);
            break;
        case ArithmeticOperation::SQUARED_DIFF:
        {
            res = (a - b) * (a - b);
            break;
        }
        case ArithmeticOperation::PRELU:
        {
            res = (a > 0 ? a : a * b);
            break;
        }
        case ArithmeticOperation::DIV:
        {
            res = a / b;
            break;
        }
        case ArithmeticOperation::POWER:
        {
            res = std::pow(a, b);
            break;
        }
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }
    return res;
}

template <ArithmeticOperation op>
inline uint8_t elementwise_arithm_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    return quantize_qasymm8(elementwise_arithm_op_scalar<op>(a, b), qinfo);
}

template <ArithmeticOperation op>
inline int8_t elementwise_arithm_op_quantized_signed_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    return quantize_qasymm8_signed(elementwise_arithm_op_scalar<op>(a, b), qinfo);
}

template <ArithmeticOperation op, typename VectorType>
inline typename VectorType::type elementwise_arithm_op(const typename VectorType::type &a, const typename VectorType::type &b)
{
    using vec_type    = typename VectorType::type;
    using scalar_type = typename VectorType::scalar_type;
    using tag_type    = typename VectorType::tag_type;

    vec_type res = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});

    switch(op)
    {
        case ArithmeticOperation::MAX:
            res = wrapper::vmax(a, b);
            break;
        case ArithmeticOperation::MIN:
            res = wrapper::vmin(a, b);
            break;
        case ArithmeticOperation::SQUARED_DIFF:
        {
            const vec_type tmp = wrapper::vsub(a, b);
            res                = wrapper::vmul(tmp, tmp);
            break;
        }
        case ArithmeticOperation::PRELU:
        {
            const vec_type zero = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
            const vec_type tmp  = wrapper::vmul(a, b);
            const auto     gt   = wrapper::vcgt(a, zero);

            res = wrapper::vbsl(gt, a, tmp);
            break;
        }

        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }

    return res;
}

template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
    return wrapper::vdiv(a, b);
}

template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
    return wrapper::vpow(a, b);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
{
    return wrapper::vdiv(a, b);
}

template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
{
    return wrapper::vpow(a, b);
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

template <ArithmeticOperation op>
inline float32x4x4_t elementwise_arithm_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
    using neon_vector_float = wrapper::traits::neon_vector<float, 4>;
    float32x4x4_t out =
    {
        {
            elementwise_arithm_op<op, neon_vector_float>(a.val[0], b.val[0]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[1], b.val[1]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[2], b.val[2]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[3], b.val[3]),
        }
    };
    return out;
}

template <ArithmeticOperation op, typename ScalarType, typename VectorType>
inline typename VectorType::type elementwise_arithm_op_broadcast(const typename VectorType::type &a, const ScalarType &broadcast_value, const bool reorder)
{
    using tag_type = typename VectorType::tag_type;
    using vec_type = typename VectorType::type;

    vec_type broadcast_vector = wrapper::vdup_n(broadcast_value, tag_type{});
    return elementwise_arithm_op<op, VectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
}

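// Scalar reference implementation of the comparisons; returns 0xFF for true
// and 0x00 for false, matching the vectorized masks.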
template <ComparisonOperation op, typename InputScalarType>
inline uint8_t elementwise_comp_op_scalar(const InputScalarType &a, const InputScalarType &b)
{
    bool res = false;

    switch(op)
    {
        case ComparisonOperation::Equal:
            res = (a == b);
            break;
        case ComparisonOperation::NotEqual:
            res = (a != b);
            break;
        case ComparisonOperation::Greater:
            res = (a > b);
            break;
        case ComparisonOperation::GreaterEqual:
            res = (a >= b);
            break;
        case ComparisonOperation::Less:
            res = (a < b);
            break;
        case ComparisonOperation::LessEqual:
            res = (a <= b);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }
    return res ? ~static_cast<uint8_t>(0) : static_cast<uint8_t>(0);
}

template <ComparisonOperation op>
inline uint8_t elementwise_comp_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    ARM_COMPUTE_UNUSED(qinfo);
    return elementwise_comp_op_scalar<op>(a, b);
}

template <ComparisonOperation op, typename InputVectorType, typename OutputVectorType>
inline OutputVectorType elementwise_comp_op(const InputVectorType &a, const InputVectorType &b)
{
    OutputVectorType res = { 0, 0, 0, 0 };

    switch(op)
    {
        case ComparisonOperation::Equal:
            res = wrapper::vceq(a, b);
            break;
        case ComparisonOperation::NotEqual:
            res = wrapper::vnot(wrapper::vceq(a, b));
            break;
        case ComparisonOperation::Greater:
            res = wrapper::vcgt(a, b);
            break;
        case ComparisonOperation::GreaterEqual:
            res = wrapper::vcge(a, b);
            break;
        case ComparisonOperation::Less:
            res = wrapper::vcgt(b, a);
            break;
        case ComparisonOperation::LessEqual:
            res = wrapper::vcge(b, a);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }

    return res;
}

template <ComparisonOperation op>
inline uint32x4x4_t elementwise_comp_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
    uint32x4x4_t out =
    {
        {
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[0], b.val[0]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[1], b.val[1]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[2], b.val[2]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[3], b.val[3])
        }
    };
    return out;
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType, typename OutputVectorType>
inline OutputVectorType elementwise_comp_op_broadcast(const InputVectorType &a, const InputScalarType &broadcast_value, const bool reorder)
{
    InputVectorType broadcast_vector = wrapper::vdup_n(broadcast_value, wrapper::traits::vector_128_tag());
    return elementwise_comp_op<op, InputVectorType, OutputVectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
}

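// The *_loop helpers below process as many whole vectors as fit in
// [window_start_x, window_end_x) and return the index from which the caller's
// scalar tail loop has to continue.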
template <ArithmeticOperation op, typename ScalarType, typename VectorType>
inline int elementwise_arithm_op_loop(int window_start_x, int window_end_x, int window_step_x,
                                      const ScalarType *input1_ptr, const ScalarType *input2_ptr, ScalarType *output_ptr)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = wrapper::vloadq(input1_ptr + x);
        const auto b = wrapper::vloadq(input2_ptr + x);
        wrapper::vstore(output_ptr + x, elementwise_arithm_op<op, VectorType>(a, b));
    }
    return x;
}

template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
                                                const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
                                                int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                float32x4_t voffseto, float32x4_t invvscaleo)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        // Get inputs and compute output
        const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
        const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
        store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}

template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_signed_loop(int window_start_x, int window_end_x, int window_step_x,
                                                       const int8_t *input1_ptr, const int8_t *input2_ptr, int8_t *output_ptr,
                                                       int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                       float32x4_t voffseto, float32x4_t invvscaleo)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        // Get inputs and compute output
        const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
        const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
        store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}

template <ArithmeticOperation op, typename ScalarType, typename VectorType>
inline int elementwise_arithm_op_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                const ScalarType *non_broadcast_input_ptr, const ScalarType &broadcast_value, ScalarType *output_ptr, const bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = wrapper::vloadq((non_broadcast_input_ptr + x));
        wrapper::vstore(output_ptr + x, elementwise_arithm_op_broadcast<op, ScalarType, VectorType>(a, broadcast_value, reorder));
    }
    return x;
}

template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                          const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                          int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                          float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}

template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                                 const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, int8_t *output_ptr,
                                                                 int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                                 float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}

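// Vectorized comparison loops; the _8/_16/_32 suffix refers to the width in
// bits of the input elements, while the output is always a uint8_t mask.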
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_8_loop(int window_start_x, int window_end_x, int window_step_x,
                                      const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a   = wrapper::vloadq(input1_ptr + x);
        const auto b   = wrapper::vloadq(input2_ptr + x);
        const auto res = elementwise_comp_op<op, InputVectorType, uint8x16_t>(a, b);
        wrapper::vstore(output_ptr + x, res);
    }
    return x;
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_16_loop(int window_start_x, int window_end_x, int window_step_x,
                                       const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a   = wrapper::vloadq(input1_ptr + x);
        const auto b   = wrapper::vloadq(input2_ptr + x);
        const auto res = elementwise_comp_op<op, InputVectorType, uint16x8_t>(a, b);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(res));
    }
    return x;
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_32_loop(int window_start_x, int window_end_x, int window_step_x,
                                       const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        auto       a    = wrapper::vloadq(input1_ptr + x);
        auto       b    = wrapper::vloadq(input2_ptr + x);
        const auto res  = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
        a               = wrapper::vloadq(input1_ptr + x + 4);
        b               = wrapper::vloadq(input2_ptr + x + 4);
        const auto res2 = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(res), wrapper::vmovn(res2))));
    }
    if(x <= window_end_x - 4)
    {
        const auto a   = wrapper::vloadq(input1_ptr + x);
        const auto b   = wrapper::vloadq(input2_ptr + x);
        const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
        for(int i = 0; i < 4; i++)
        {
            *(output_ptr + x + i) = wrapper::vgetlane(res, i);
        }
        x += 4;
    }
    return x;
}

template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
                                              const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
                                              int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                              float32x4_t voffseto, float32x4_t invvscaleo)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(af, bf);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}

template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_signed_loop(int window_start_x, int window_end_x, int window_step_x,
                                                     const int8_t *input1_ptr, const int8_t *input2_ptr, uint8_t *output_ptr,
                                                     int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                     float32x4_t voffseto, float32x4_t invvscaleo)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(af, bf);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_broadcast_8_loop(int window_start_x, int window_end_x, int window_step_x,
                                                const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint8x16_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
        wrapper::vstore(output_ptr + x, a);
    }
    return x;
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_broadcast_16_loop(int window_start_x, int window_end_x, int window_step_x,
                                                 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint16x8_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(a));
    }
    return x;
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_broadcast_32_loop(int window_start_x, int window_end_x, int window_step_x,
                                                 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x), broadcast_value, reorder);
        const auto b = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x + 4), broadcast_value, reorder);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(a), wrapper::vmovn(b))));
    }
    if(x <= window_end_x - 4)
    {
        const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
        for(int i = 0; i < 4; i++)
        {
            *(output_ptr + x + i) = wrapper::vgetlane(a, i);
        }
        x += 4;
    }
    return x;
}

template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                        const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                        int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                        float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}

template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                               const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                               int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                               float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}

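// Generic driver for the non-quantized kernels: it prepares the execution
// window, handles broadcasting along the X axis and dispatches to the given
// vectorized loop, using the scalar functor for the leftover elements.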
template <typename InputScalarType, typename OutputScalarType, typename InputVectorType>
void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                    OutputScalarType (*scalar_func)(const InputScalarType &, const InputScalarType &),
                    int (*broadcast_func)(int, int, int, const InputScalarType *, const InputScalarType &, OutputScalarType *, const bool),
                    int (*neon_func)(int, int, int, const InputScalarType *, const InputScalarType *, OutputScalarType *))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = std::min(16 / static_cast<int>(sizeof(OutputScalarType)), 8);
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    if(is_broadcast_across_x)
    {
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            auto                  output_ptr              = reinterpret_cast<OutputScalarType *>(output.ptr());
            const auto            non_broadcast_input_ptr = reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
            const InputScalarType broadcast_value         = *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_value, output_ptr, !is_broadcast_input_2);
            for(; x < window_end_x; ++x)
            {
                const auto a      = *(non_broadcast_input_ptr + x);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? broadcast_value : a, !is_broadcast_input_2 ? a : broadcast_value);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            auto       output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
            const auto input1_ptr = reinterpret_cast<const InputScalarType *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const InputScalarType *>(input2.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr);
            for(; x < window_end_x; ++x)
            {
                const auto a      = *(input1_ptr + x);
                const auto b      = *(input2_ptr + x);
                *(output_ptr + x) = (*scalar_func)(a, b);
            }
        },
        input1, input2, output);
    }
}

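// Same driver specialized for QASYMM8 tensors: the inputs are dequantized to
// float, processed in blocks of 16 elements and requantized with the output
// UniformQuantizationInfo.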
void elementwise_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                              uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                              int (*broadcast_func)(int, int, int, const uint8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
                                                    float32x4_t, float32x4_t, const bool),
                              int (*neon_func)(int, int, int, const uint8_t *, const uint8_t *, uint8_t *,
                                               int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                               float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16;
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    // Output quantization info (add 0.5 to round toward the nearest integer - 0.5 rounds away from zero)
    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset + 0.5f);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const uint8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<uint8_t *>(output.ptr());

            const uint8_t       broadcast_value  = *reinterpret_cast<const uint8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_u8(broadcast_value), broadcast_qinfo);

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}

void elementwise_comp_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                                       uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                                       int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
                                                             float32x4_t, float32x4_t, const bool),
                                       int (*neon_func)(int, int, int, const int8_t *, const int8_t *, uint8_t *,
                                                        int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                                        float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16;
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<uint8_t *>(output.ptr());

            const int8_t        broadcast_value  = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo);

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}

void elementwise_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                                     int8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                                     int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, int8_t *, int32x4_t, float32x4_t,
                                                           float32x4_t, float32x4_t, const bool),
                                     int (*neon_func)(int, int, int, const int8_t *, const int8_t *, int8_t *,
                                                      int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                                      float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16;
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<int8_t *>(output.ptr());

            const int8_t        broadcast_value  = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo);

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
void elementwise_comp_op_8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
                                                              &elementwise_comp_op_scalar<op, InputScalarType>,
                                                              &elementwise_comp_op_broadcast_8_loop<op, InputScalarType, InputVectorType>,
                                                              &elementwise_comp_op_8_loop<op, InputScalarType, InputVectorType>);
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
void elementwise_comp_op_16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
                                                              &elementwise_comp_op_scalar<op, InputScalarType>,
                                                              &elementwise_comp_op_broadcast_16_loop<op, InputScalarType, InputVectorType>,
                                                              &elementwise_comp_op_16_loop<op, InputScalarType, InputVectorType>);
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
void elementwise_comp_op_32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
                                                              &elementwise_comp_op_scalar<op, InputScalarType>,
                                                              &elementwise_comp_op_broadcast_32_loop<op, InputScalarType, InputVectorType>,
                                                              &elementwise_comp_op_32_loop<op, InputScalarType, InputVectorType>);
}

template <ArithmeticOperation op, typename VectorType>
void elementwise_arithm_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    using scalar_type = typename VectorType::scalar_type;

    elementwise_op<scalar_type, scalar_type, VectorType>(in1, in2, out, window,
                                                         &elementwise_arithm_op_scalar<op, scalar_type>,
                                                         &elementwise_arithm_op_broadcast_loop<op, scalar_type, VectorType>,
                                                         &elementwise_arithm_op_loop<op, scalar_type, VectorType>);
}

template <ArithmeticOperation op>
void elementwise_arithm_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op_quantized(in1, in2, out, window, &elementwise_arithm_op_quantized_scalar<op>,
                             &elementwise_arithm_op_quantized_broadcast_loop<op>,
                             &elementwise_arithm_op_quantized_loop<op>);
}

template <ArithmeticOperation op>
void elementwise_arithm_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op_quantized_signed(in1, in2, out, window, &elementwise_arithm_op_quantized_signed_scalar<op>,
                                    &elementwise_arithm_op_quantized_signed_broadcast_loop<op>,
                                    &elementwise_arithm_op_quantized_signed_loop<op>);
}

template <ComparisonOperation op>
void elementwise_comp_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op_quantized(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
                             &elementwise_comp_op_quantized_broadcast_loop<op>,
                             &elementwise_comp_op_quantized_loop<op>);
}

template <ComparisonOperation op>
void elementwise_comp_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_comp_quantized_signed(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
                                      &elementwise_comp_op_quantized_signed_broadcast_loop<op>,
                                      &elementwise_comp_op_quantized_signed_loop<op>);
}

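// Builds a lookup key of the form "op_<dt1>_<dt2>_<dt_out>" from the tensor
// data types and returns the matching kernel function, or nullptr if the
// combination is not supported.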
std::function<void(const ITensor *, const ITensor *, ITensor *, const Window &)>
configure_func(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output,
               std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function)
{
    std::string function_to_call("op_");
    function_to_call += string_from_data_type(input1->data_type()) + "_";
    function_to_call += string_from_data_type(input2->data_type()) + "_";
    function_to_call += string_from_data_type(output->data_type());

    auto it = map_function.find(function_to_call);

    if(it != map_function.end())
    {
        auto func = it->second;
        return [func](const ITensor * input1, const ITensor * input2, ITensor * output, const Window & window)
        {
            func(input1, input2, output, window);
        };
    }
    return nullptr;
}

template <ArithmeticOperation op>
std::function<void(const ITensor *, const ITensor *, ITensor *, const Window &)>
configure_arithm_func(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
{
    static std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function =
    {
        { "op_F32_F32_F32", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<float, 4>> },
        { "op_S16_S16_S16", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<int16_t, 8>> },
        { "op_S32_S32_S32", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<int32_t, 4>> },
        { "op_QASYMM8_QASYMM8_QASYMM8", &elementwise_arithm_op_quantized<op> },
        { "op_QASYMM8_SIGNED_QASYMM8_SIGNED_QASYMM8_SIGNED", &elementwise_arithm_op_quantized_signed<op> }
    };
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    map_function["op_F16_F16_F16"] = &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<float16_t, 8>>;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

    return configure_func(input1, input2, output, map_function);
}

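// Comparison kernels always write a U8 output regardless of the input data type, hence the
// "_U8" suffix on every key below; as above, the F16 variant is only registered when the
// target provides FP16 vector arithmetic.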
template <ComparisonOperation op>
std::function<void(const ITensor *input1, const ITensor *input2, ITensor *output, const Window &window)>
configure_comp_func(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
{
    static std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function =
    {
        { "op_U8_U8_U8", &elementwise_comp_op_8<op, uint8_t, uint8x16_t> },
        { "op_F32_F32_U8", &elementwise_comp_op_32<op, float, float32x4_t> },
        { "op_S16_S16_U8", &elementwise_comp_op_16<op, int16_t, int16x8_t> },
        { "op_S32_S32_U8", &elementwise_comp_op_32<op, int32_t, int32x4_t> },
        { "op_QASYMM8_SIGNED_QASYMM8_SIGNED_U8", &elementwise_comp_op_quantized_signed<op> },
        { "op_QASYMM8_QASYMM8_U8", &elementwise_comp_op_quantized<op> }
    };
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    map_function["op_F16_F16_U8"] = &elementwise_comp_op_16<op, float16_t, float16x8_t>;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

    return configure_func(input1, input2, output, map_function);
}
} // namespace

NEElementwiseOperationKernel::NEElementwiseOperationKernel()
    : _function(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr)
{
}

Status NEElementwiseOperationKernel::validate_arguments_common(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &input2);

    const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

    // Validate in case of configured output
    if(output.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
                                        "Wrong shape for output");
    }

    return Status{};
}

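// Derives the broadcast output shape and valid region from the two inputs and, if the output
// has not been initialised yet, auto-initialises it with that shape and input1's data type.
// For example (illustrative shapes only), inputs of shape (16, 4) and (16, 1) yield an output
// of shape (16, 4).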
void NEElementwiseOperationKernel::configure_common(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);

    // Configure kernel window
    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
    const TensorShape &out_shape    = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    auto_init_if_empty(*output, out_shape, 1, input1->data_type());

    Window win = calculate_max_window(valid_region);

    INEKernel::configure(win);
}

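// A minimal, single-threaded usage sketch (tensor names are hypothetical; the kernel is
// typically driven through the NEElementwise* runtime functions or the scheduler rather than
// called directly):
//
//   ITensorPack pack;
//   pack.add_const_tensor(TensorType::ACL_SRC_0, &lhs); // first input tensor
//   pack.add_const_tensor(TensorType::ACL_SRC_1, &rhs); // second input tensor
//   pack.add_tensor(TensorType::ACL_DST, &dst);         // destination tensor
//   kernel.run_op(pack, kernel.window(), ThreadInfo{});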
void NEElementwiseOperationKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info, window);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_function == nullptr);
    _function(tensors.get_const_tensor(TensorType::ACL_SRC_0),
              tensors.get_const_tensor(TensorType::ACL_SRC_1),
              tensors.get_tensor(TensorType::ACL_DST), window);
}

/** Arithmetic operators (min, max, squared_diff, prelu) */
void NEArithmeticOperationKernel::configure(ArithmeticOperation op, const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1, *input2, *output));
    configure_common(input1, input2, output);
    switch(op)
    {
        case ArithmeticOperation::MAX:
            _function = configure_arithm_func<ArithmeticOperation::MAX>(input1, input2, output);
            break;
        case ArithmeticOperation::MIN:
            _function = configure_arithm_func<ArithmeticOperation::MIN>(input1, input2, output);
            break;
        case ArithmeticOperation::SQUARED_DIFF:
            _function = configure_arithm_func<ArithmeticOperation::SQUARED_DIFF>(input1, input2, output);
            break;
        case ArithmeticOperation::PRELU:
            _function = configure_arithm_func<ArithmeticOperation::PRELU>(input1, input2, output);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }
}

Status NEArithmeticOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
    // Validate in case of configured output
    if(output.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &output);
    }
    return validate_arguments_common(input1, input2, output);
}

Status NEArithmeticOperationKernel::validate(ArithmeticOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_UNUSED(op);
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}
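
// A minimal configuration sketch (hypothetical shapes; the explicit validate() call is optional
// and mirrors the checks configure() performs internally):
//
//   TensorInfo a_info(TensorShape(16U, 4U), 1, DataType::F32);
//   TensorInfo b_info(TensorShape(16U, 1U), 1, DataType::F32); // broadcast along the second dimension
//   TensorInfo dst_info(TensorShape(16U, 4U), 1, DataType::F32);
//   ARM_COMPUTE_ERROR_THROW_ON(NEArithmeticOperationKernel::validate(ArithmeticOperation::MIN, &a_info, &b_info, &dst_info));
//   NEArithmeticOperationKernel kernel;
//   kernel.configure(ArithmeticOperation::MIN, &a_info, &b_info, &dst_info);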

/** The division operator */

void NEDivisionOperationKernel::configure(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1, *input2, *output));
    configure_common(input1, input2, output);
    _function = configure_arithm_func<ArithmeticOperation::DIV>(input1, input2, output);
}

Status NEDivisionOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::F16, DataType::F32);
    return NEArithmeticOperationKernel::validate_arguments(input1, input2, output);
}

Status NEDivisionOperationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}

/** The power operator */
void NEPowerOperationKernel::configure(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1, *input2, *output));
    configure_common(input1, input2, output);
    _function = configure_arithm_func<ArithmeticOperation::POWER>(input1, input2, output);
}

Status NEPowerOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::F16, DataType::F32);
    return NEArithmeticOperationKernel::validate_arguments(input1, input2, output);
}

Status NEPowerOperationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}

/** Comparison operators (equal, not equal, less than, greater than, less than or equal, greater than or equal) */
void NEComparisonOperationKernel::configure(ComparisonOperation op, const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1, *input2, *output));
    configure_common(input1, input2, output);
    switch(op)
    {
        case ComparisonOperation::Equal:
            _function = configure_comp_func<ComparisonOperation::Equal>(input1, input2, output);
            break;
        case ComparisonOperation::NotEqual:
            _function = configure_comp_func<ComparisonOperation::NotEqual>(input1, input2, output);
            break;
        case ComparisonOperation::Greater:
            _function = configure_comp_func<ComparisonOperation::Greater>(input1, input2, output);
            break;
        case ComparisonOperation::GreaterEqual:
            _function = configure_comp_func<ComparisonOperation::GreaterEqual>(input1, input2, output);
            break;
        case ComparisonOperation::Less:
            _function = configure_comp_func<ComparisonOperation::Less>(input1, input2, output);
            break;
        case ComparisonOperation::LessEqual:
            _function = configure_comp_func<ComparisonOperation::LessEqual>(input1, input2, output);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }
}

Status NEComparisonOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
    // Validate in case of configured output
    if(output.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8);
    }
    return validate_arguments_common(input1, input2, output);
}

Status NEComparisonOperationKernel::validate(ComparisonOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_UNUSED(op);
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}
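
// A minimal comparison sketch (hypothetical shapes): the output must always be U8, whatever the
// input data type.
//
//   TensorInfo lhs_info(TensorShape(8U), 1, DataType::F32);
//   TensorInfo rhs_info(TensorShape(8U), 1, DataType::F32);
//   TensorInfo out_info(TensorShape(8U), 1, DataType::U8);
//   NEComparisonOperationKernel kernel;
//   kernel.configure(ComparisonOperation::Greater, &lhs_info, &rhs_info, &out_info);
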
} // namespace arm_compute