/*
2 * Copyright (c) 2021-2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H
25#define SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H
26
27#include "src/core/NEON/NEAsymm.h"
28
29namespace arm_compute
30{
31namespace cpu
32{
33template <ArithmeticOperation op, typename VectorType>
34typename VectorType::type elementwise_arithm_op(const typename VectorType::type &a, const typename VectorType::type &b)
35{
36 using vec_type = typename VectorType::type;
37 using scalar_type = typename VectorType::scalar_type;
38 using tag_type = typename VectorType::tag_type;
39
40 vec_type res = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
41
42 switch(op)
43 {
44 case ArithmeticOperation::MAX:
45 res = wrapper::vmax(a, b);
46 break;
47 case ArithmeticOperation::MIN:
48 res = wrapper::vmin(a, b);
49 break;
50 case ArithmeticOperation::SQUARED_DIFF:
51 {
52 const vec_type tmp = wrapper::vsub(a, b);
53 res = wrapper::vmul(tmp, tmp);
54 break;
55 }
56 case ArithmeticOperation::PRELU:
57 {
58 const vec_type zero = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
59 const vec_type tmp = wrapper::vmul(a, b);
60 const auto gt = wrapper::vcgt(a, zero);
61
62 res = wrapper::vbsl(gt, a, tmp);
63 break;
64 }
65
66 default:
67 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
68 }
69
70 return res;
71}
72template <ArithmeticOperation op, typename ScalarType, typename VectorType>
73typename VectorType::type elementwise_arithm_op_broadcast(const typename VectorType::type &a, const ScalarType &broadcast_value, const bool reorder)
74{
75 using tag_type = typename VectorType::tag_type;
76 using vec_type = typename VectorType::type;
77
78 vec_type broadcast_vector = wrapper::vdup_n(broadcast_value, tag_type{});
79 return elementwise_arithm_op<op, VectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
80}
81
/** Generic driver for binary element-wise kernels over a window.
 *
 * Dispatches between a broadcast path (one input has extent 1 on X) and the
 * plain path. Each path runs a vectorized loop followed by a scalar tail.
 *
 * @param in1            First input tensor.
 * @param in2            Second input tensor.
 * @param out            Output tensor.
 * @param window         Execution window.
 * @param scalar_func    Scalar fallback used for the loop tail.
 * @param broadcast_func Vectorized loop for the broadcast case; returns the first unprocessed index.
 * @param neon_func      Vectorized loop for the non-broadcast case; returns the first unprocessed index.
 */
template <typename InputScalarType, typename OutputScalarType, typename InputVectorType>
void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                    OutputScalarType (*scalar_func)(const InputScalarType &, const InputScalarType &),
                    int (*broadcast_func)(int, int, int, const InputScalarType *, const InputScalarType &, OutputScalarType *, const bool),
                    int (*neon_func)(int, int, int, const InputScalarType *, const InputScalarType *, OutputScalarType *))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    // Step covers at most 16 bytes per iteration, capped at 8 lanes
    const int window_step_x = std::min(16 / static_cast<int>(sizeof(OutputScalarType)), 8);
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();

    if(is_broadcast_across_x)
    {
        // A broadcast input has step 0 on X after broadcast_if_dimension_le_one()
        const bool is_broadcast_input_2 = input2_win.x().step() == 0;
        Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
        Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
            const auto non_broadcast_input_ptr = reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
            const InputScalarType broadcast_value = *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());

            // Vectorized part; "!is_broadcast_input_2" keeps the original operand order when input 1 is the broadcast one
            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_value, output_ptr, !is_broadcast_input_2);
            // Scalar tail for the leftover elements
            for(; x < window_end_x; ++x)
            {
                const auto a = *(non_broadcast_input_ptr + x);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? broadcast_value : a, !is_broadcast_input_2 ? a : broadcast_value);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
            const auto input1_ptr = reinterpret_cast<const InputScalarType *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const InputScalarType *>(input2.ptr());

            // Vectorized part, then scalar tail
            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr);
            for(; x < window_end_x; ++x)
            {
                const auto a = *(input1_ptr + x);
                const auto b = *(input2_ptr + x);
                *(output_ptr + x) = (*scalar_func)(a, b);
            }
        },
        input1, input2, output);
    }
}
158
159template <ArithmeticOperation op, typename ScalarType>
160inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const ScalarType &b)
161{
162 auto res = ScalarType(0);
163
164 switch(op)
165 {
166 case ArithmeticOperation::MAX:
167 res = std::max(a, b);
168 break;
169 case ArithmeticOperation::MIN:
170 res = std::min(a, b);
171 break;
172 case ArithmeticOperation::SQUARED_DIFF:
173 {
174 res = (a - b) * (a - b);
175 break;
176 }
177 case ArithmeticOperation::PRELU:
178 {
179 res = (a > 0 ? a : a * b);
180 break;
181 }
182 case ArithmeticOperation::DIV:
183 {
184 res = a / b;
185 if(std::is_integral<ScalarType>::value)
186 {
187 res = (b == 0) ? 0 : res;
188 if(static_cast<int32_t>(a) % static_cast<int32_t>(b) != 0 && ((a < 0) != (b < 0)))
189 {
190 --res;
191 }
192 }
193 break;
194 }
195 case ArithmeticOperation::POWER:
196 {
197 res = std::pow(a, b);
198 break;
199 }
200 default:
201 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
202 }
203 return res;
204}
205
206template <>
207inline int32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<int32_t, 4>>(const int32x4_t &a, const int32x4_t &b)
208{
209 return vcvtq_s32_f32(vfloorq_f32(wrapper::vdiv(vcvtq_f32_s32(a), vcvtq_f32_s32(b))));
210}
211
/** DIV specialization for float32x4_t: plain lane-wise division. */
template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
    return wrapper::vdiv(a, b);
}
217
/** POWER specialization for float32x4_t: lane-wise a^b. */
template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
    return wrapper::vpow(a, b);
}
223
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** DIV specialization for float16x8_t: plain lane-wise division. */
template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
{
    return wrapper::vdiv(a, b);
}

/** POWER specialization for float16x8_t: lane-wise a^b. */
template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
{
    return wrapper::vpow(a, b);
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
237
238template <ArithmeticOperation op, typename ScalarType, typename VectorType>
239inline int elementwise_arithm_op_loop(int window_start_x, int window_end_x, int window_step_x,
240 const ScalarType *input1_ptr, const ScalarType *input2_ptr, ScalarType *output_ptr)
241{
242 int x = window_start_x;
243 for(; x <= (window_end_x - window_step_x); x += window_step_x)
244 {
245 const auto a = wrapper::vloadq(input1_ptr + x);
246 const auto b = wrapper::vloadq(input2_ptr + x);
247 wrapper::vstore(output_ptr + x, elementwise_arithm_op<op, VectorType>(a, b));
248 }
249 return x;
250}
251
252template <ArithmeticOperation op, typename ScalarType, typename VectorType>
253inline int elementwise_arithm_op_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
254 const ScalarType *non_broadcast_input_ptr, const ScalarType &broadcast_value, ScalarType *output_ptr, const bool reorder)
255{
256 int x = window_start_x;
257 for(; x <= (window_end_x - window_step_x); x += window_step_x)
258 {
259 const auto a = wrapper::vloadq((non_broadcast_input_ptr + x));
260 wrapper::vstore(output_ptr + x, elementwise_arithm_op_broadcast<op, ScalarType, VectorType>(a, broadcast_value, reorder));
261 }
262 return x;
263}
264
/** Tensor-level entry point for arithmetic ops.
 *
 * Wires the scalar, broadcast-loop and vector-loop implementations of @p op
 * into the generic elementwise_op driver.
 */
template <ArithmeticOperation op, typename VectorType>
void elementwise_arithm_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    using scalar_type = typename VectorType::scalar_type;

    elementwise_op<scalar_type, scalar_type, VectorType>(in1, in2, out, window,
                                                         &elementwise_arithm_op_scalar<op, scalar_type>,
                                                         &elementwise_arithm_op_broadcast_loop<op, scalar_type, VectorType>,
                                                         &elementwise_arithm_op_loop<op, scalar_type, VectorType>);
}
275
276template <ComparisonOperation op, typename InputScalarType>
277inline uint8_t elementwise_comp_op_scalar(const InputScalarType &a, const InputScalarType &b)
278{
279 bool res = false;
280
281 switch(op)
282 {
283 case ComparisonOperation::Equal:
284 res = (a == b);
285 break;
286 case ComparisonOperation::NotEqual:
287 res = (a != b);
288 break;
289 case ComparisonOperation::Greater:
290 res = (a > b);
291 break;
292 case ComparisonOperation::GreaterEqual:
293 res = (a >= b);
294 break;
295 case ComparisonOperation::Less:
296 res = (a < b);
297 break;
298 case ComparisonOperation::LessEqual:
299 res = (a <= b);
300 break;
301 default:
302 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
303 }
304 return res ? ~static_cast<uint8_t>(0) : static_cast<uint8_t>(0);
305}
306
307template <ComparisonOperation op, typename InputVectorType, typename OutputVectorType>
308inline OutputVectorType elementwise_comp_op(const InputVectorType &a, const InputVectorType &b)
309{
310 OutputVectorType res = { 0, 0, 0, 0 };
311
312 switch(op)
313 {
314 case ComparisonOperation::Equal:
315 res = wrapper::vceq(a, b);
316 break;
317 case ComparisonOperation::NotEqual:
318 res = wrapper::vnot(wrapper::vceq(a, b));
319 break;
320 case ComparisonOperation::Greater:
321 res = wrapper::vcgt(a, b);
322 break;
323 case ComparisonOperation::GreaterEqual:
324 res = wrapper::vcge(a, b);
325 break;
326 case ComparisonOperation::Less:
327 res = wrapper::vcgt(b, a);
328 break;
329 case ComparisonOperation::LessEqual:
330 res = wrapper::vcge(b, a);
331 break;
332 default:
333 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
334 }
335
336 return res;
337}
338
339template <ComparisonOperation op, typename InputScalarType, typename InputVectorType, typename OutputVectorType>
340inline OutputVectorType elementwise_comp_op_broadcast(const InputVectorType &a, const InputScalarType &broadcast_value, const bool reorder)
341{
342 InputVectorType broadcast_vector = wrapper::vdup_n(broadcast_value, wrapper::traits::vector_128_tag());
343 return elementwise_comp_op<op, InputVectorType, OutputVectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
344}
345
346template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
347inline int elementwise_comp_op_broadcast_8_loop(int window_start_x, int window_end_x, int window_step_x,
348 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
349{
350 int x = window_start_x;
351 for(; x <= (window_end_x - window_step_x); x += window_step_x)
352 {
353 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint8x16_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
354 wrapper::vstore(output_ptr + x, a);
355 }
356 return x;
357}
358
359template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
360inline int elementwise_comp_op_broadcast_16_loop(int window_start_x, int window_end_x, int window_step_x,
361 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
362{
363 int x = window_start_x;
364 for(; x <= (window_end_x - window_step_x); x += window_step_x)
365 {
366 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint16x8_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
367 wrapper::vstore(output_ptr + x, wrapper::vmovn(a));
368 }
369 return x;
370}
371
372template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
373inline int elementwise_comp_op_broadcast_32_loop(int window_start_x, int window_end_x, int window_step_x,
374 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
375{
376 int x = window_start_x;
377 for(; x <= (window_end_x - window_step_x); x += window_step_x)
378 {
379 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x), broadcast_value, reorder);
380 const auto b = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x + 4), broadcast_value, reorder);
381 wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(a), wrapper::vmovn(b))));
382 }
383 if(x <= window_end_x - 4)
384 {
385 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
386 for(int i = 0; i < 4; i++)
387 {
388 *(output_ptr + x + i) = wrapper::vgetlane(a, i);
389 }
390 x = +4;
391 }
392 return x;
393}
394
395template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
396inline int elementwise_comp_op_8_loop(int window_start_x, int window_end_x, int window_step_x,
397 const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
398{
399 int x = window_start_x;
400 for(; x <= (window_end_x - window_step_x); x += window_step_x)
401 {
402 const auto a = wrapper::vloadq(input1_ptr + x);
403 const auto b = wrapper::vloadq(input2_ptr + x);
404 const auto res = elementwise_comp_op<op, InputVectorType, uint8x16_t>(a, b);
405 wrapper::vstore(output_ptr + x, res);
406 }
407 return x;
408}
409
410template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
411inline int elementwise_comp_op_16_loop(int window_start_x, int window_end_x, int window_step_x,
412 const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
413{
414 int x = window_start_x;
415 for(; x <= (window_end_x - window_step_x); x += window_step_x)
416 {
417 const auto a = wrapper::vloadq(input1_ptr + x);
418 const auto b = wrapper::vloadq(input2_ptr + x);
419 const auto res = elementwise_comp_op<op, InputVectorType, uint16x8_t>(a, b);
420 wrapper::vstore(output_ptr + x, wrapper::vmovn(res));
421 }
422 return x;
423}
424
425template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
426inline int elementwise_comp_op_32_loop(int window_start_x, int window_end_x, int window_step_x,
427 const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
428{
429 int x = window_start_x;
430 for(; x <= (window_end_x - window_step_x); x += window_step_x)
431 {
432 auto a = wrapper::vloadq(input1_ptr + x);
433 auto b = wrapper::vloadq(input2_ptr + x);
434 const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
435 a = wrapper::vloadq(input1_ptr + x + 4);
436 b = wrapper::vloadq(input2_ptr + x + 4);
437 const auto res2 = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
438 wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(res), wrapper::vmovn(res2))));
439 }
440 if(x <= window_end_x - 4)
441 {
442 const auto a = wrapper::vloadq(input1_ptr + x);
443 const auto b = wrapper::vloadq(input2_ptr + x);
444 const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
445 for(int i = 0; i < 4; i++)
446 {
447 *(output_ptr + x + i) = wrapper::vgetlane(res, i);
448 }
449 x = +4;
450 }
451 return x;
452}
453
/** Tensor-level entry point for comparisons on 8-bit inputs: wires the scalar,
 *  broadcast and vector loops for @p op into the generic elementwise_op driver. */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
void elementwise_comp_op_8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
                                                              &elementwise_comp_op_scalar<op, InputScalarType>,
                                                              &elementwise_comp_op_broadcast_8_loop<op, InputScalarType, InputVectorType>,
                                                              &elementwise_comp_op_8_loop<op, InputScalarType, InputVectorType>);
}
462
/** Tensor-level entry point for comparisons on 16-bit inputs: wires the scalar,
 *  broadcast and vector loops for @p op into the generic elementwise_op driver. */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
void elementwise_comp_op_16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
                                                              &elementwise_comp_op_scalar<op, InputScalarType>,
                                                              &elementwise_comp_op_broadcast_16_loop<op, InputScalarType, InputVectorType>,
                                                              &elementwise_comp_op_16_loop<op, InputScalarType, InputVectorType>);
}
471
/** Tensor-level entry point for comparisons on 32-bit inputs: wires the scalar,
 *  broadcast and vector loops for @p op into the generic elementwise_op driver. */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
void elementwise_comp_op_32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
                                                              &elementwise_comp_op_scalar<op, InputScalarType>,
                                                              &elementwise_comp_op_broadcast_32_loop<op, InputScalarType, InputVectorType>,
                                                              &elementwise_comp_op_32_loop<op, InputScalarType, InputVectorType>);
}
480
/** Load 16 QASYMM8 values and dequantize them into four float32x4 vectors.
 *
 * Each lane becomes (value - offset) * scale, widening u8 -> u16 -> s32 -> f32.
 */
inline float32x4x4_t load_quantized(const uint8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
{
    qasymm8x16_t x = vld1q_u8(input1_ptr);
    const float32x4x4_t out =
    {
        {
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
        }
    };
    return out;
}
495
/** Load 16 QASYMM8_SIGNED values and dequantize them into four float32x4 vectors.
 *
 * Each lane becomes (value - offset) * scale, widening s8 -> s16 -> s32 -> f32.
 */
inline float32x4x4_t load_quantized_signed(const int8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
{
    qasymm8x16_signed_t x = vld1q_s8(input1_ptr);
    const float32x4x4_t out =
    {
        {
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
        }
    };
    return out;
}
510
/** Saturating-narrow four uint32x4 vectors (u32 -> u16 -> u8) and store 16 bytes. */
inline void store_quantized(uint8_t *output_ptr, const uint32x4x4_t &out)
{
    const uint8x8_t pa = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[0]), vqmovn_u32(out.val[1])));
    const uint8x8_t pb = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[2]), vqmovn_u32(out.val[3])));
    vst1q_u8(output_ptr, vcombine_u8(pa, pb));
}
517
/** Saturating-narrow four int32x4 vectors to unsigned bytes (s32 -> s16 -> u8, clamping negatives to 0) and store 16 bytes. */
inline void store_quantized(uint8_t *output_ptr, const int32x4x4_t &out)
{
    const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
    const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
    vst1q_u8(output_ptr, vcombine_u8(pa, pb));
}
524
/** Requantize four float32x4 vectors (rf * invscale + offset, truncated to int)
 *  and store them as 16 QASYMM8 values.
 *
 *  Callers pre-bias @p offset by +0.5 (see elementwise_op_quantized) so that
 *  vcvtq's truncation effectively rounds to nearest.
 */
inline void store_quantized(uint8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
{
    int32x4x4_t out =
    {
        {
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)),
        }
    };
    store_quantized(output_ptr, out);
}
538
/** Saturating-narrow four int32x4 vectors to signed bytes (s32 -> s16 -> s8) and store 16 bytes. */
inline void store_quantized_signed(int8_t *output_ptr, const int32x4x4_t &out)
{
    const int8x8_t pa = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
    const int8x8_t pb = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
    vst1q_s8(output_ptr, vcombine_s8(pa, pb));
}
545
/** Requantize four float32x4 vectors (rf * invscale + offset, truncated to int)
 *  and store them as 16 QASYMM8_SIGNED values.
 *
 *  Callers pre-bias @p offset by +0.5 (see elementwise_op_quantized) so that
 *  vcvtq's truncation effectively rounds to nearest.
 */
inline void store_quantized_signed(int8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
{
    int32x4x4_t out =
    {
        {
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)),
        }
    };
    store_quantized_signed(output_ptr, out);
}
559
/** Scalar fallback for quantized arithmetic: compute @p op in float, then requantize to QASYMM8. */
template <ArithmeticOperation op>
inline uint8_t elementwise_arithm_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    return quantize_qasymm8(elementwise_arithm_op_scalar<op>(a, b), qinfo);
}
565
/** Scalar fallback for quantized arithmetic: compute @p op in float, then requantize to QASYMM8_SIGNED. */
template <ArithmeticOperation op>
inline int8_t elementwise_arithm_op_quantized_signed_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    return quantize_qasymm8_signed(elementwise_arithm_op_scalar<op>(a, b), qinfo);
}
571
/** Apply @p op to each of the four float32x4 sub-vectors of a float32x4x4_t
 *  (the dequantized representation of 16 quantized values). */
template <ArithmeticOperation op>
float32x4x4_t elementwise_arithm_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
    using neon_vector_float = wrapper::traits::neon_vector<float, 4>;
    float32x4x4_t out =
    {
        {
            elementwise_arithm_op<op, neon_vector_float>(a.val[0], b.val[0]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[1], b.val[1]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[2], b.val[2]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[3], b.val[3]),
        }
    };
    return out;
}
587
/** Scalar fallback for quantized comparisons.
 *
 * Comparisons operate on the dequantized floats directly, so the output
 * quantization info is unused.
 */
template <ComparisonOperation op>
inline uint8_t elementwise_comp_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    ARM_COMPUTE_UNUSED(qinfo);
    return elementwise_comp_op_scalar<op>(a, b);
}
594
/** Apply the comparison @p op to each of the four float32x4 sub-vectors,
 *  producing four uint32x4 lane masks. */
template <ComparisonOperation op>
inline uint32x4x4_t elementwise_comp_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
    uint32x4x4_t out =
    {
        {
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[0], b.val[0]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[1], b.val[1]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[2], b.val[2]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[3], b.val[3])
        }
    };
    return out;
}
609
/** Vectorized main loop for QASYMM8 arithmetic: dequantize both inputs,
 *  apply @p op in float, requantize the result.
 *
 * @return Index of the first element not processed (tail handled by the caller).
 */
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
                                                const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
                                                int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                float32x4_t voffseto, float32x4_t invvscaleo)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        // Get inputs and compute output
        const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
        const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
        store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
627
/** Vectorized main loop for QASYMM8_SIGNED arithmetic: dequantize both inputs,
 *  apply @p op in float, requantize the result.
 *
 *  NOTE(review): "singed" in the name looks like a typo for "signed"; kept
 *  unchanged so existing callers are not broken.
 *
 * @return Index of the first element not processed (tail handled by the caller).
 */
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_singed_loop(int window_start_x, int window_end_x, int window_step_x,
                                                       const int8_t *input1_ptr, const int8_t *input2_ptr, int8_t *output_ptr,
                                                       int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                       float32x4_t voffseto, float32x4_t invvscaleo)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        // Get inputs and compute output
        const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
        const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
        store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
645
/** Vectorized broadcast loop for QASYMM8 arithmetic: the broadcast operand is
 *  pre-dequantized into @p broadcast_vector by the caller.
 *
 * @param reorder True when the broadcast value is the first operand.
 *
 * @return Index of the first element not processed (tail handled by the caller).
 */
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                          const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                          int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                          float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
/** Vectorized broadcast loop for QASYMM8_SIGNED arithmetic: the broadcast
 *  operand is pre-dequantized into @p broadcast_vector by the caller.
 *
 * @param reorder True when the broadcast value is the first operand.
 *
 * @return Index of the first element not processed (tail handled by the caller).
 */
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                                 const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, int8_t *output_ptr,
                                                                 int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                                 float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
676
/** Vectorized main loop for QASYMM8 comparisons: dequantize both inputs and
 *  compare in float. Output quantization parameters are unused since the
 *  result is a 0x00/0xFF mask.
 *
 * @return Index of the first element not processed (tail handled by the caller).
 */
template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
                                              const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
                                              int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                              float32x4_t voffseto, float32x4_t invvscaleo)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(af, bf);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}
694
/** Vectorized main loop for QASYMM8_SIGNED comparisons: dequantize both inputs
 *  and compare in float. Output quantization parameters are unused since the
 *  result is a 0x00/0xFF mask.
 *
 * @return Index of the first element not processed (tail handled by the caller).
 */
template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_signed_loop(int window_start_x, int window_end_x, int window_step_x,
                                                     const int8_t *input1_ptr, const int8_t *input2_ptr, uint8_t *output_ptr,
                                                     int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                     float32x4_t voffseto, float32x4_t invvscaleo)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(af, bf);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}
712
/** Vectorized broadcast loop for QASYMM8 comparisons: the broadcast operand is
 *  pre-dequantized into @p broadcast_vector by the caller.
 *
 * @param reorder True when the broadcast value is the first operand.
 *
 * @return Index of the first element not processed (tail handled by the caller).
 */
template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                        const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                        int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                        float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}
729
/** Vectorized broadcast loop for QASYMM8_SIGNED comparisons: the broadcast
 *  operand is pre-dequantized into @p broadcast_vector by the caller.
 *
 * @param reorder True when the broadcast value is the first operand.
 *
 * @return Index of the first element not processed (tail handled by the caller).
 */
template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                               const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                               int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                               float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}
746
/** Generic launcher for a binary element-wise operation on two QASYMM8 inputs producing a QASYMM8 output.
 *
 * Inputs are dequantized to float, combined by the supplied callbacks and re-quantized with the
 * output tensor's quantization info. The X dimension is collapsed on the execution window and
 * traversed manually: a vectorized callback consumes 16 elements per step and returns the index
 * of the first unprocessed element, which a scalar callback then finishes off.
 *
 * @param[in]  in1            First input tensor.
 * @param[in]  in2            Second input tensor. Either input may broadcast along X.
 * @param[out] out            Output tensor.
 * @param[in]  window         Execution window.
 * @param[in]  scalar_func    Scalar fallback applied to the dequantized tail elements.
 * @param[in]  broadcast_func Vectorized loop used when one input broadcasts along X.
 * @param[in]  neon_func      Vectorized loop used when both inputs have a matching X extent.
 */
inline void elementwise_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                                     uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                                     int (*broadcast_func)(int, int, int, const uint8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
                                                           float32x4_t, float32x4_t, const bool),
                                     int (*neon_func)(int, int, int, const uint8_t *, const uint8_t *, uint8_t *,
                                                      int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                                      float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16; // elements processed per vectorized iteration
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    // Output quantization info (add 0.5 to round toward the nearest integer - 0.5 rounds away from zero)
    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset + 0.5f);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const uint8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<uint8_t *>(output.ptr());

            // Dequantize the single broadcast value once per row and splat it across a full vector.
            const uint8_t       broadcast_value  = *reinterpret_cast<const uint8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_u8(broadcast_value), broadcast_qinfo);

            // !is_broadcast_input_2 (reorder) keeps operand order as (in1, in2) for non-commutative ops.
            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            // Scalar tail for the leftover elements the vector loop did not cover.
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            // Scalar tail for the leftover elements the vector loop did not cover.
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}
854
/** Launcher for a binary comparison on two QASYMM8_SIGNED inputs producing a U8 output.
 *
 * Mirrors elementwise_op_quantized but reads int8 inputs and writes uint8 results.
 * NOTE(review): voffseto here has no +0.5 rounding term — presumably because comparison
 * kernels emit boolean masks and ignore the output quantization (the visible broadcast
 * loops mark voffseto/invvscaleo ARM_COMPUTE_UNUSED); confirm against the neon_func kernels.
 *
 * @param[in]  in1            First input tensor (signed quantized).
 * @param[in]  in2            Second input tensor (signed quantized). Either input may broadcast along X.
 * @param[out] out            Output tensor (uint8).
 * @param[in]  window         Execution window.
 * @param[in]  scalar_func    Scalar fallback applied to the dequantized tail elements.
 * @param[in]  broadcast_func Vectorized loop used when one input broadcasts along X.
 * @param[in]  neon_func      Vectorized loop used when both inputs have a matching X extent.
 */
inline void elementwise_comp_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                                              uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                                              int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
                                                                    float32x4_t, float32x4_t, const bool),
                                              int (*neon_func)(int, int, int, const int8_t *, const int8_t *, uint8_t *,
                                                               int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                                               float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16; // elements processed per vectorized iteration
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<uint8_t *>(output.ptr());

            // Dequantize the single broadcast value once per row and splat it across a full vector.
            const int8_t        broadcast_value  = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo);

            // !is_broadcast_input_2 (reorder) keeps operand order as (in1, in2) for non-commutative comparisons.
            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            // Scalar tail for the leftover elements the vector loop did not cover.
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            // Scalar tail for the leftover elements the vector loop did not cover.
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}
961
/** Launcher for a binary element-wise operation on two QASYMM8_SIGNED inputs producing a QASYMM8_SIGNED output.
 *
 * Mirrors elementwise_op_quantized but with int8 input and output buffers.
 * NOTE(review): unlike the unsigned variant, voffseto carries no +0.5 rounding term here —
 * presumably the signed store path applies its own rounding; confirm against the
 * neon_func/broadcast_func kernels before relying on this.
 *
 * @param[in]  in1            First input tensor (signed quantized).
 * @param[in]  in2            Second input tensor (signed quantized). Either input may broadcast along X.
 * @param[out] out            Output tensor (signed quantized).
 * @param[in]  window         Execution window.
 * @param[in]  scalar_func    Scalar fallback applied to the dequantized tail elements.
 * @param[in]  broadcast_func Vectorized loop used when one input broadcasts along X.
 * @param[in]  neon_func      Vectorized loop used when both inputs have a matching X extent.
 */
inline void elementwise_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                                            int8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                                            int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, int8_t *, int32x4_t, float32x4_t,
                                                                  float32x4_t, float32x4_t, const bool),
                                            int (*neon_func)(int, int, int, const int8_t *, const int8_t *, int8_t *,
                                                             int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                                             float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16; // elements processed per vectorized iteration
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<int8_t *>(output.ptr());

            // Dequantize the single broadcast value once per row and splat it across a full vector.
            const int8_t        broadcast_value  = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo);

            // !is_broadcast_input_2 (reorder) keeps operand order as (in1, in2) for non-commutative ops.
            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            // Scalar tail for the leftover elements the vector loop did not cover.
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            // Scalar tail for the leftover elements the vector loop did not cover.
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}
1068
1069template <ArithmeticOperation op>
1070void elementwise_arithm_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1071{
1072 elementwise_op_quantized(in1, in2, out, window, &elementwise_arithm_op_quantized_scalar<op>,
1073 &elementwise_arithm_op_quantized_broadcast_loop<op>,
1074 &elementwise_arithm_op_quantized_loop<op>);
1075}
1076
1077template <ArithmeticOperation op>
1078void elementwise_arithm_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1079{
1080 elementwise_op_quantized_signed(in1, in2, out, window, &elementwise_arithm_op_quantized_signed_scalar<op>,
1081 &elementwise_arithm_op_quantized_signed_broadcast_loop<op>,
1082 &elementwise_arithm_op_quantized_singed_loop<op>);
1083}
1084
1085template <ComparisonOperation op>
1086void elementwise_comp_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1087{
1088 elementwise_op_quantized(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
1089 &elementwise_comp_op_quantized_broadcast_loop<op>,
1090 &elementwise_comp_op_quantized_loop<op>);
1091}
1092
1093template <ComparisonOperation op>
1094void elementwise_comp_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1095{
1096 elementwise_comp_quantized_signed(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
1097 &elementwise_comp_op_quantized_signed_broadcast_loop<op>,
1098 &elementwise_comp_op_quantized_signed_loop<op>);
1099}
1100} // namespace cpu
1101} // namespace arm_compute
1102
1103#endif /* SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H */