blob: 98b154e8fdf2bff649b2bbad11c17344bd3ae8b2 [file] [log] [blame]
Dana Zlotnikd5c496d2021-11-28 14:46:12 +02001/*
2 * Copyright (c) 2021-2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H
25#define SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H
26
27#include "src/core/NEON/NEAsymm.h"
28
29namespace arm_compute
30{
31namespace cpu
32{
33template <ArithmeticOperation op, typename VectorType>
34typename VectorType::type elementwise_arithm_op(const typename VectorType::type &a, const typename VectorType::type &b)
35{
36 using vec_type = typename VectorType::type;
37 using scalar_type = typename VectorType::scalar_type;
38 using tag_type = typename VectorType::tag_type;
39
40 vec_type res = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
41
42 switch(op)
43 {
44 case ArithmeticOperation::MAX:
45 res = wrapper::vmax(a, b);
46 break;
47 case ArithmeticOperation::MIN:
48 res = wrapper::vmin(a, b);
49 break;
50 case ArithmeticOperation::SQUARED_DIFF:
51 {
52 const vec_type tmp = wrapper::vsub(a, b);
53 res = wrapper::vmul(tmp, tmp);
54 break;
55 }
56 case ArithmeticOperation::PRELU:
57 {
58 const vec_type zero = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
59 const vec_type tmp = wrapper::vmul(a, b);
60 const auto gt = wrapper::vcgt(a, zero);
61
62 res = wrapper::vbsl(gt, a, tmp);
63 break;
64 }
65
66 default:
67 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
68 }
69
70 return res;
71}
Dana Zlotnika538ae52022-02-21 13:12:41 +020072
Dana Zlotnikd5c496d2021-11-28 14:46:12 +020073template <ArithmeticOperation op, typename ScalarType, typename VectorType>
74typename VectorType::type elementwise_arithm_op_broadcast(const typename VectorType::type &a, const ScalarType &broadcast_value, const bool reorder)
75{
76 using tag_type = typename VectorType::tag_type;
77 using vec_type = typename VectorType::type;
78
79 vec_type broadcast_vector = wrapper::vdup_n(broadcast_value, tag_type{});
80 return elementwise_arithm_op<op, VectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
81}
82
83template <typename InputScalarType, typename OutputScalarType, typename InputVectorType>
84void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
85 OutputScalarType (*scalar_func)(const InputScalarType &, const InputScalarType &),
86 int (*broadcast_func)(int, int, int, const InputScalarType *, const InputScalarType &, OutputScalarType *, const bool),
87 int (*neon_func)(int, int, int, const InputScalarType *, const InputScalarType *, OutputScalarType *))
88{
89 // Create input windows
90 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
91 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
92
93 // Clear X Dimension on execution window as we handle manually
94 Window win = window;
95 win.set(Window::DimX, Window::Dimension(0, 1, 1));
96
97 const int window_step_x = std::min(16 / static_cast<int>(sizeof(OutputScalarType)), 8);
98 const auto window_start_x = static_cast<int>(window.x().start());
99 const auto window_end_x = static_cast<int>(window.x().end());
100 const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
101
102 if(is_broadcast_across_x)
103 {
104 const bool is_broadcast_input_2 = input2_win.x().step() == 0;
105 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
106 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
107 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
108 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
109
110 // Clear X Dimension on execution window as we handle manually
111 non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
112
113 Iterator broadcast_input(broadcast_tensor, broadcast_win);
114 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
115 Iterator output(out, win);
116
117 execute_window_loop(win, [&](const Coordinates &)
118 {
119 auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
120 const auto non_broadcast_input_ptr = reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
121 const InputScalarType broadcast_value = *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());
122
123 int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_value, output_ptr, !is_broadcast_input_2);
124 for(; x < window_end_x; ++x)
125 {
126 const auto a = *(non_broadcast_input_ptr + x);
127 *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? broadcast_value : a, !is_broadcast_input_2 ? a : broadcast_value);
128 }
129 },
130 broadcast_input, non_broadcast_input, output);
131 }
132 else
133 {
134 // Clear X Dimension on execution window as we handle manually
135 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
136 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
137
138 Iterator input1(in1, input1_win);
139 Iterator input2(in2, input2_win);
140 Iterator output(out, win);
141
142 execute_window_loop(win, [&](const Coordinates &)
143 {
144 auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
145 const auto input1_ptr = reinterpret_cast<const InputScalarType *>(input1.ptr());
146 const auto input2_ptr = reinterpret_cast<const InputScalarType *>(input2.ptr());
147
148 int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr);
149 for(; x < window_end_x; ++x)
150 {
151 const auto a = *(input1_ptr + x);
152 const auto b = *(input2_ptr + x);
153 *(output_ptr + x) = (*scalar_func)(a, b);
154 }
155 },
156 input1, input2, output);
157 }
158}
159
160template <ArithmeticOperation op, typename ScalarType>
161inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const ScalarType &b)
162{
163 auto res = ScalarType(0);
164
165 switch(op)
166 {
167 case ArithmeticOperation::MAX:
168 res = std::max(a, b);
169 break;
170 case ArithmeticOperation::MIN:
171 res = std::min(a, b);
172 break;
173 case ArithmeticOperation::SQUARED_DIFF:
174 {
175 res = (a - b) * (a - b);
176 break;
177 }
178 case ArithmeticOperation::PRELU:
179 {
180 res = (a > 0 ? a : a * b);
181 break;
182 }
183 case ArithmeticOperation::DIV:
184 {
185 res = a / b;
186 if(std::is_integral<ScalarType>::value)
187 {
188 res = (b == 0) ? 0 : res;
189 if(static_cast<int32_t>(a) % static_cast<int32_t>(b) != 0 && ((a < 0) != (b < 0)))
190 {
191 --res;
192 }
193 }
194 break;
195 }
196 case ArithmeticOperation::POWER:
197 {
198 res = std::pow(a, b);
199 break;
200 }
201 default:
202 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
203 }
204 return res;
205}
206
207template <>
208inline int32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<int32_t, 4>>(const int32x4_t &a, const int32x4_t &b)
209{
210 return vcvtq_s32_f32(vfloorq_f32(wrapper::vdiv(vcvtq_f32_s32(a), vcvtq_f32_s32(b))));
211}
212
213template <>
214inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
215{
216 return wrapper::vdiv(a, b);
217}
218
219template <>
220inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
221{
222 return wrapper::vpow(a, b);
223}
224
225#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
226template <>
227inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
228{
229 return wrapper::vdiv(a, b);
230}
231
232template <>
233inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
234{
235 return wrapper::vpow(a, b);
236}
237#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
238
239template <ArithmeticOperation op, typename ScalarType, typename VectorType>
240inline int elementwise_arithm_op_loop(int window_start_x, int window_end_x, int window_step_x,
241 const ScalarType *input1_ptr, const ScalarType *input2_ptr, ScalarType *output_ptr)
242{
243 int x = window_start_x;
244 for(; x <= (window_end_x - window_step_x); x += window_step_x)
245 {
246 const auto a = wrapper::vloadq(input1_ptr + x);
247 const auto b = wrapper::vloadq(input2_ptr + x);
248 wrapper::vstore(output_ptr + x, elementwise_arithm_op<op, VectorType>(a, b));
249 }
250 return x;
251}
252
253template <ArithmeticOperation op, typename ScalarType, typename VectorType>
254inline int elementwise_arithm_op_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
255 const ScalarType *non_broadcast_input_ptr, const ScalarType &broadcast_value, ScalarType *output_ptr, const bool reorder)
256{
257 int x = window_start_x;
258 for(; x <= (window_end_x - window_step_x); x += window_step_x)
259 {
260 const auto a = wrapper::vloadq((non_broadcast_input_ptr + x));
261 wrapper::vstore(output_ptr + x, elementwise_arithm_op_broadcast<op, ScalarType, VectorType>(a, broadcast_value, reorder));
262 }
263 return x;
264}
265
266template <ArithmeticOperation op, typename VectorType>
267void elementwise_arithm_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
268{
269 using scalar_type = typename VectorType::scalar_type;
270
271 elementwise_op<scalar_type, scalar_type, VectorType>(in1, in2, out, window,
272 &elementwise_arithm_op_scalar<op, scalar_type>,
273 &elementwise_arithm_op_broadcast_loop<op, scalar_type, VectorType>,
274 &elementwise_arithm_op_loop<op, scalar_type, VectorType>);
275}
276
277template <ComparisonOperation op, typename InputScalarType>
278inline uint8_t elementwise_comp_op_scalar(const InputScalarType &a, const InputScalarType &b)
279{
280 bool res = false;
281
282 switch(op)
283 {
284 case ComparisonOperation::Equal:
285 res = (a == b);
286 break;
287 case ComparisonOperation::NotEqual:
288 res = (a != b);
289 break;
290 case ComparisonOperation::Greater:
291 res = (a > b);
292 break;
293 case ComparisonOperation::GreaterEqual:
294 res = (a >= b);
295 break;
296 case ComparisonOperation::Less:
297 res = (a < b);
298 break;
299 case ComparisonOperation::LessEqual:
300 res = (a <= b);
301 break;
302 default:
303 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
304 }
305 return res ? ~static_cast<uint8_t>(0) : static_cast<uint8_t>(0);
306}
307
308template <ComparisonOperation op, typename InputVectorType, typename OutputVectorType>
309inline OutputVectorType elementwise_comp_op(const InputVectorType &a, const InputVectorType &b)
310{
311 OutputVectorType res = { 0, 0, 0, 0 };
312
313 switch(op)
314 {
315 case ComparisonOperation::Equal:
316 res = wrapper::vceq(a, b);
317 break;
318 case ComparisonOperation::NotEqual:
319 res = wrapper::vnot(wrapper::vceq(a, b));
320 break;
321 case ComparisonOperation::Greater:
322 res = wrapper::vcgt(a, b);
323 break;
324 case ComparisonOperation::GreaterEqual:
325 res = wrapper::vcge(a, b);
326 break;
327 case ComparisonOperation::Less:
328 res = wrapper::vcgt(b, a);
329 break;
330 case ComparisonOperation::LessEqual:
331 res = wrapper::vcge(b, a);
332 break;
333 default:
334 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
335 }
336
337 return res;
338}
339
340template <ComparisonOperation op, typename InputScalarType, typename InputVectorType, typename OutputVectorType>
341inline OutputVectorType elementwise_comp_op_broadcast(const InputVectorType &a, const InputScalarType &broadcast_value, const bool reorder)
342{
343 InputVectorType broadcast_vector = wrapper::vdup_n(broadcast_value, wrapper::traits::vector_128_tag());
344 return elementwise_comp_op<op, InputVectorType, OutputVectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
345}
346
347template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
348inline int elementwise_comp_op_broadcast_8_loop(int window_start_x, int window_end_x, int window_step_x,
349 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
350{
351 int x = window_start_x;
352 for(; x <= (window_end_x - window_step_x); x += window_step_x)
353 {
354 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint8x16_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
355 wrapper::vstore(output_ptr + x, a);
356 }
357 return x;
358}
359
360template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
361inline int elementwise_comp_op_broadcast_16_loop(int window_start_x, int window_end_x, int window_step_x,
362 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
363{
364 int x = window_start_x;
365 for(; x <= (window_end_x - window_step_x); x += window_step_x)
366 {
367 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint16x8_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
368 wrapper::vstore(output_ptr + x, wrapper::vmovn(a));
369 }
370 return x;
371}
372
373template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
374inline int elementwise_comp_op_broadcast_32_loop(int window_start_x, int window_end_x, int window_step_x,
375 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
376{
377 int x = window_start_x;
378 for(; x <= (window_end_x - window_step_x); x += window_step_x)
379 {
380 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x), broadcast_value, reorder);
381 const auto b = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x + 4), broadcast_value, reorder);
382 wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(a), wrapper::vmovn(b))));
383 }
384 if(x <= window_end_x - 4)
385 {
386 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
387 for(int i = 0; i < 4; i++)
388 {
389 *(output_ptr + x + i) = wrapper::vgetlane(a, i);
390 }
391 x = +4;
392 }
393 return x;
394}
395
396template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
397inline int elementwise_comp_op_8_loop(int window_start_x, int window_end_x, int window_step_x,
398 const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
399{
400 int x = window_start_x;
401 for(; x <= (window_end_x - window_step_x); x += window_step_x)
402 {
403 const auto a = wrapper::vloadq(input1_ptr + x);
404 const auto b = wrapper::vloadq(input2_ptr + x);
405 const auto res = elementwise_comp_op<op, InputVectorType, uint8x16_t>(a, b);
406 wrapper::vstore(output_ptr + x, res);
407 }
408 return x;
409}
410
411template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
412inline int elementwise_comp_op_16_loop(int window_start_x, int window_end_x, int window_step_x,
413 const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
414{
415 int x = window_start_x;
416 for(; x <= (window_end_x - window_step_x); x += window_step_x)
417 {
418 const auto a = wrapper::vloadq(input1_ptr + x);
419 const auto b = wrapper::vloadq(input2_ptr + x);
420 const auto res = elementwise_comp_op<op, InputVectorType, uint16x8_t>(a, b);
421 wrapper::vstore(output_ptr + x, wrapper::vmovn(res));
422 }
423 return x;
424}
425
426template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
427inline int elementwise_comp_op_32_loop(int window_start_x, int window_end_x, int window_step_x,
428 const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
429{
430 int x = window_start_x;
431 for(; x <= (window_end_x - window_step_x); x += window_step_x)
432 {
433 auto a = wrapper::vloadq(input1_ptr + x);
434 auto b = wrapper::vloadq(input2_ptr + x);
435 const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
436 a = wrapper::vloadq(input1_ptr + x + 4);
437 b = wrapper::vloadq(input2_ptr + x + 4);
438 const auto res2 = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
439 wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(res), wrapper::vmovn(res2))));
440 }
441 if(x <= window_end_x - 4)
442 {
443 const auto a = wrapper::vloadq(input1_ptr + x);
444 const auto b = wrapper::vloadq(input2_ptr + x);
445 const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
446 for(int i = 0; i < 4; i++)
447 {
448 *(output_ptr + x + i) = wrapper::vgetlane(res, i);
449 }
450 x = +4;
451 }
452 return x;
453}
454
455template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
456void elementwise_comp_op_8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
457{
458 elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
459 &elementwise_comp_op_scalar<op, InputScalarType>,
460 &elementwise_comp_op_broadcast_8_loop<op, InputScalarType, InputVectorType>,
461 &elementwise_comp_op_8_loop<op, InputScalarType, InputVectorType>);
462}
463
464template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
465void elementwise_comp_op_16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
466{
467 elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
468 &elementwise_comp_op_scalar<op, InputScalarType>,
469 &elementwise_comp_op_broadcast_16_loop<op, InputScalarType, InputVectorType>,
470 &elementwise_comp_op_16_loop<op, InputScalarType, InputVectorType>);
471}
472
473template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
474void elementwise_comp_op_32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
475{
476 elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
477 &elementwise_comp_op_scalar<op, InputScalarType>,
478 &elementwise_comp_op_broadcast_32_loop<op, InputScalarType, InputVectorType>,
479 &elementwise_comp_op_32_loop<op, InputScalarType, InputVectorType>);
480}
481
482inline float32x4x4_t load_quantized(const uint8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
483{
484 qasymm8x16_t x = vld1q_u8(input1_ptr);
485 const float32x4x4_t out =
486 {
487 {
488 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
489 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
490 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
491 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
492 }
493 };
494 return out;
495}
496
497inline float32x4x4_t load_quantized_signed(const int8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
498{
499 qasymm8x16_signed_t x = vld1q_s8(input1_ptr);
500 const float32x4x4_t out =
501 {
502 {
503 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
504 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
505 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
506 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
507 }
508 };
509 return out;
510}
511
512inline void store_quantized(uint8_t *output_ptr, const uint32x4x4_t &out)
513{
514 const uint8x8_t pa = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[0]), vqmovn_u32(out.val[1])));
515 const uint8x8_t pb = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[2]), vqmovn_u32(out.val[3])));
516 vst1q_u8(output_ptr, vcombine_u8(pa, pb));
517}
518
519inline void store_quantized(uint8_t *output_ptr, const int32x4x4_t &out)
520{
521 const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
522 const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
523 vst1q_u8(output_ptr, vcombine_u8(pa, pb));
524}
525
526inline void store_quantized(uint8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
527{
528 int32x4x4_t out =
529 {
530 {
531 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)),
532 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)),
533 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)),
534 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)),
535 }
536 };
537 store_quantized(output_ptr, out);
538}
539
540inline void store_quantized_signed(int8_t *output_ptr, const int32x4x4_t &out)
541{
542 const int8x8_t pa = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
543 const int8x8_t pb = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
544 vst1q_s8(output_ptr, vcombine_s8(pa, pb));
545}
546
547inline void store_quantized_signed(int8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
548{
549 int32x4x4_t out =
550 {
551 {
552 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)),
553 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)),
554 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)),
555 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)),
556 }
557 };
558 store_quantized_signed(output_ptr, out);
559}
560
561template <ArithmeticOperation op>
562inline uint8_t elementwise_arithm_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
563{
564 return quantize_qasymm8(elementwise_arithm_op_scalar<op>(a, b), qinfo);
565}
566
567template <ArithmeticOperation op>
568inline int8_t elementwise_arithm_op_quantized_signed_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
569{
570 return quantize_qasymm8_signed(elementwise_arithm_op_scalar<op>(a, b), qinfo);
571}
572
573template <ArithmeticOperation op>
574float32x4x4_t elementwise_arithm_op(const float32x4x4_t &a, const float32x4x4_t &b)
575{
576 using neon_vector_float = wrapper::traits::neon_vector<float, 4>;
577 float32x4x4_t out =
578 {
579 {
580 elementwise_arithm_op<op, neon_vector_float>(a.val[0], b.val[0]),
581 elementwise_arithm_op<op, neon_vector_float>(a.val[1], b.val[1]),
582 elementwise_arithm_op<op, neon_vector_float>(a.val[2], b.val[2]),
583 elementwise_arithm_op<op, neon_vector_float>(a.val[3], b.val[3]),
584 }
585 };
586 return out;
587}
588
589template <ComparisonOperation op>
590inline uint8_t elementwise_comp_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
591{
592 ARM_COMPUTE_UNUSED(qinfo);
593 return elementwise_comp_op_scalar<op>(a, b);
594}
595
596template <ComparisonOperation op>
597inline uint32x4x4_t elementwise_comp_op(const float32x4x4_t &a, const float32x4x4_t &b)
598{
599 uint32x4x4_t out =
600 {
601 {
602 elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[0], b.val[0]),
603 elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[1], b.val[1]),
604 elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[2], b.val[2]),
605 elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[3], b.val[3])
606 }
607 };
608 return out;
609}
610
611template <ArithmeticOperation op>
612inline int elementwise_arithm_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
613 const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
614 int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
615 float32x4_t voffseto, float32x4_t invvscaleo)
616{
617 int x = window_start_x;
618 for(; x <= (window_end_x - window_step_x); x += window_step_x)
619 {
620 // Get inputs and compute output
621 const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
622 const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
623 const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
624 store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
625 }
626 return x;
627}
628
629template <ArithmeticOperation op>
630inline int elementwise_arithm_op_quantized_singed_loop(int window_start_x, int window_end_x, int window_step_x,
631 const int8_t *input1_ptr, const int8_t *input2_ptr, int8_t *output_ptr,
632 int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
633 float32x4_t voffseto, float32x4_t invvscaleo)
634{
635 int x = window_start_x;
636 for(; x <= (window_end_x - window_step_x); x += window_step_x)
637 {
638 // Get inputs and compute output
639 const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
640 const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
641 const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
642 store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
643 }
644 return x;
645}
646
647template <ArithmeticOperation op>
648inline int elementwise_arithm_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
649 const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
650 int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
651 float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
652{
653 int x = window_start_x;
654 for(; x <= (window_end_x - window_step_x); x += window_step_x)
655 {
656 const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
657 const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
658 store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
659 }
660 return x;
661}
662template <ArithmeticOperation op>
663inline int elementwise_arithm_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
664 const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, int8_t *output_ptr,
665 int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
666 float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
667{
668 int x = window_start_x;
669 for(; x <= (window_end_x - window_step_x); x += window_step_x)
670 {
671 const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
672 const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
673 store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
674 }
675 return x;
676}
677
678template <ComparisonOperation op>
679inline int elementwise_comp_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
680 const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
681 int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
682 float32x4_t voffseto, float32x4_t invvscaleo)
683{
684 ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
685 int x = window_start_x;
686 for(; x <= (window_end_x - window_step_x); x += window_step_x)
687 {
688 const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
689 const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
690 const uint32x4x4_t rf = elementwise_comp_op<op>(af, bf);
691 store_quantized(output_ptr + x, rf);
692 }
693 return x;
694}
695
696template <ComparisonOperation op>
697inline int elementwise_comp_op_quantized_signed_loop(int window_start_x, int window_end_x, int window_step_x,
698 const int8_t *input1_ptr, const int8_t *input2_ptr, uint8_t *output_ptr,
699 int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
700 float32x4_t voffseto, float32x4_t invvscaleo)
701{
702 ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
703 int x = window_start_x;
704 for(; x <= (window_end_x - window_step_x); x += window_step_x)
705 {
706 const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
707 const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
708 const uint32x4x4_t rf = elementwise_comp_op<op>(af, bf);
709 store_quantized(output_ptr + x, rf);
710 }
711 return x;
712}
713
714template <ComparisonOperation op>
715inline int elementwise_comp_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
716 const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
717 int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
718 float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
719{
720 ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
721 int x = window_start_x;
722 for(; x <= (window_end_x - window_step_x); x += window_step_x)
723 {
724 const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
725 const uint32x4x4_t rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
726 store_quantized(output_ptr + x, rf);
727 }
728 return x;
729}
730
731template <ComparisonOperation op>
732inline int elementwise_comp_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
733 const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
734 int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
735 float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
736{
737 ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
738 int x = window_start_x;
739 for(; x <= (window_end_x - window_step_x); x += window_step_x)
740 {
741 const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
742 const uint32x4x4_t rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
743 store_quantized(output_ptr + x, rf);
744 }
745 return x;
746}
747
748inline void elementwise_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
749 uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
750 int (*broadcast_func)(int, int, int, const uint8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
751 float32x4_t, float32x4_t, const bool),
752 int (*neon_func)(int, int, int, const uint8_t *, const uint8_t *, uint8_t *,
753 int32x4_t, int32x4_t, float32x4_t, float32x4_t,
754 float32x4_t, float32x4_t))
755{
756 // Create input windows
757 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
758 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
759
760 // Clear X Dimension on execution window as we handle manually
761 Window win = window;
762 win.set(Window::DimX, Window::Dimension(0, 1, 1));
763
764 const int window_step_x = 16;
765 const auto window_start_x = static_cast<int>(window.x().start());
766 const auto window_end_x = static_cast<int>(window.x().end());
767 const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
768
769 const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();
770
771 // Output quantization info (add 0.5 to round toward the nearest integer - 0.5 rounds away from zero)
772 const float32x4_t voffseto = vdupq_n_f32(output_qinfo.offset + 0.5f);
773 const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);
774
775 if(is_broadcast_across_x)
776 {
777 // Select the broadcast input on the X axis
778 const bool is_broadcast_input_2 = input2_win.x().step() == 0;
779 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
780 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
781 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
782 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
783
784 const UniformQuantizationInfo broadcast_qinfo = broadcast_tensor->info()->quantization_info().uniform();
785 const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();
786
787 const int32x4_t voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
788 const float32x4_t vscale_non_broadcast = vdupq_n_f32(non_broadcast_qinfo.scale);
789
790 // Clear X Dimension on execution window as we handle manually
791 non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
792
793 Iterator broadcast_input(broadcast_tensor, broadcast_win);
794 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
795 Iterator output(out, win);
796
797 execute_window_loop(win, [&](const Coordinates &)
798 {
799 const auto non_broadcast_input_ptr = reinterpret_cast<const uint8_t *>(non_broadcast_input.ptr());
800 const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
801
802 const uint8_t broadcast_value = *reinterpret_cast<const uint8_t *>(broadcast_input.ptr());
803 const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_u8(broadcast_value), broadcast_qinfo);
804
805 int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
806 voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
807 for(; x < window_end_x; ++x)
808 {
809 const float afs = dequantize_qasymm8(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
810 const float bfs = dequantize_qasymm8(broadcast_value, broadcast_qinfo);
811 *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
812 }
813 },
814 broadcast_input, non_broadcast_input, output);
815 }
816 else
817 {
818 const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
819 const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();
820
821 // Input1 quantization info
822 const int32x4_t voffset1 = vdupq_n_s32(input1_qinfo.offset);
823 const float32x4_t vscale1 = vdupq_n_f32(input1_qinfo.scale);
824
825 // Input2 quantization info
826 const int32x4_t voffset2 = vdupq_n_s32(input2_qinfo.offset);
827 const float32x4_t vscale2 = vdupq_n_f32(input2_qinfo.scale);
828
829 // Clear X Dimension on execution window as we handle manually
830 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
831 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
832
833 Iterator input1(in1, input1_win);
834 Iterator input2(in2, input2_win);
835 Iterator output(out, win);
836
837 execute_window_loop(win, [&](const Coordinates &)
838 {
839 const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
840 const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
841 const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
842
843 int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
844 vscale1, vscale2, voffseto, invvscaleo);
845 for(; x < window_end_x; ++x)
846 {
847 const float afs = dequantize_qasymm8(*(input1_ptr + x), input1_qinfo);
848 const float bfs = dequantize_qasymm8(*(input2_ptr + x), input2_qinfo);
849 *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
850 }
851 },
852 input1, input2, output);
853 }
854}
855
856inline void elementwise_comp_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
857 uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
858 int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
859 float32x4_t, float32x4_t, const bool),
860 int (*neon_func)(int, int, int, const int8_t *, const int8_t *, uint8_t *,
861 int32x4_t, int32x4_t, float32x4_t, float32x4_t,
862 float32x4_t, float32x4_t))
863{
864 // Create input windows
865 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
866 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
867
868 // Clear X Dimension on execution window as we handle manually
869 Window win = window;
870 win.set(Window::DimX, Window::Dimension(0, 1, 1));
871
872 const int window_step_x = 16;
873 const auto window_start_x = static_cast<int>(window.x().start());
874 const auto window_end_x = static_cast<int>(window.x().end());
875 const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
876
877 const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();
878
879 const float32x4_t voffseto = vdupq_n_f32(output_qinfo.offset);
880 const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);
881
882 if(is_broadcast_across_x)
883 {
884 // Select the broadcast input on the X axis
885 const bool is_broadcast_input_2 = input2_win.x().step() == 0;
886 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
887 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
888 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
889 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
890
891 const UniformQuantizationInfo broadcast_qinfo = broadcast_tensor->info()->quantization_info().uniform();
892 const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();
893
894 const int32x4_t voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
895 const float32x4_t vscale_non_broadcast = vdupq_n_f32(non_broadcast_qinfo.scale);
896
897 // Clear X Dimension on execution window as we handle manually
898 non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
899
900 Iterator broadcast_input(broadcast_tensor, broadcast_win);
901 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
902 Iterator output(out, win);
903
904 execute_window_loop(win, [&](const Coordinates &)
905 {
906 const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
907 const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
908
909 const int8_t broadcast_value = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
910 const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo);
911
912 int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
913 voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
914 for(; x < window_end_x; ++x)
915 {
916 const float afs = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
917 const float bfs = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
918 *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
919 }
920 },
921 broadcast_input, non_broadcast_input, output);
922 }
923 else
924 {
925 const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
926 const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();
927
928 // Input1 quantization info
929 const int32x4_t voffset1 = vdupq_n_s32(input1_qinfo.offset);
930 const float32x4_t vscale1 = vdupq_n_f32(input1_qinfo.scale);
931
932 // Input2 quantization info
933 const int32x4_t voffset2 = vdupq_n_s32(input2_qinfo.offset);
934 const float32x4_t vscale2 = vdupq_n_f32(input2_qinfo.scale);
935
936 // Clear X Dimension on execution window as we handle manually
937 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
938 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
939
940 Iterator input1(in1, input1_win);
941 Iterator input2(in2, input2_win);
942 Iterator output(out, win);
943
944 execute_window_loop(win, [&](const Coordinates &)
945 {
946 const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
947 const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
948 const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
949
950 int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
951 vscale1, vscale2, voffseto, invvscaleo);
952 for(; x < window_end_x; ++x)
953 {
954 const float afs = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
955 const float bfs = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
956 *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
957 }
958 },
959 input1, input2, output);
960 }
961}
962
963inline void elementwise_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
964 int8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
965 int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, int8_t *, int32x4_t, float32x4_t,
966 float32x4_t, float32x4_t, const bool),
967 int (*neon_func)(int, int, int, const int8_t *, const int8_t *, int8_t *,
968 int32x4_t, int32x4_t, float32x4_t, float32x4_t,
969 float32x4_t, float32x4_t))
970{
971 // Create input windows
972 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
973 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
974
975 // Clear X Dimension on execution window as we handle manually
976 Window win = window;
977 win.set(Window::DimX, Window::Dimension(0, 1, 1));
978
979 const int window_step_x = 16;
980 const auto window_start_x = static_cast<int>(window.x().start());
981 const auto window_end_x = static_cast<int>(window.x().end());
982 const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
983
984 const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();
985
986 const float32x4_t voffseto = vdupq_n_f32(output_qinfo.offset);
987 const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);
988
989 if(is_broadcast_across_x)
990 {
991 // Select the broadcast input on the X axis
992 const bool is_broadcast_input_2 = input2_win.x().step() == 0;
993 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
994 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
995 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
996 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
997
998 const UniformQuantizationInfo broadcast_qinfo = broadcast_tensor->info()->quantization_info().uniform();
999 const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();
1000
1001 const int32x4_t voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
1002 const float32x4_t vscale_non_broadcast = vdupq_n_f32(non_broadcast_qinfo.scale);
1003
1004 // Clear X Dimension on execution window as we handle manually
1005 non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1006
1007 Iterator broadcast_input(broadcast_tensor, broadcast_win);
1008 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
1009 Iterator output(out, win);
1010
1011 execute_window_loop(win, [&](const Coordinates &)
1012 {
1013 const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
1014 const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());
1015
1016 const int8_t broadcast_value = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
1017 const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo);
1018
1019 int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
1020 voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
1021 for(; x < window_end_x; ++x)
1022 {
1023 const float afs = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
1024 const float bfs = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
1025 *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
1026 }
1027 },
1028 broadcast_input, non_broadcast_input, output);
1029 }
1030 else
1031 {
1032 const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
1033 const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();
1034
1035 // Input1 quantization info
1036 const int32x4_t voffset1 = vdupq_n_s32(input1_qinfo.offset);
1037 const float32x4_t vscale1 = vdupq_n_f32(input1_qinfo.scale);
1038
1039 // Input2 quantization info
1040 const int32x4_t voffset2 = vdupq_n_s32(input2_qinfo.offset);
1041 const float32x4_t vscale2 = vdupq_n_f32(input2_qinfo.scale);
1042
1043 // Clear X Dimension on execution window as we handle manually
1044 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1045 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1046
1047 Iterator input1(in1, input1_win);
1048 Iterator input2(in2, input2_win);
1049 Iterator output(out, win);
1050
1051 execute_window_loop(win, [&](const Coordinates &)
1052 {
1053 const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
1054 const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
1055 const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());
1056
1057 int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
1058 vscale1, vscale2, voffseto, invvscaleo);
1059 for(; x < window_end_x; ++x)
1060 {
1061 const float afs = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
1062 const float bfs = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
1063 *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
1064 }
1065 },
1066 input1, input2, output);
1067 }
1068}
1069
1070template <ArithmeticOperation op>
1071void elementwise_arithm_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1072{
1073 elementwise_op_quantized(in1, in2, out, window, &elementwise_arithm_op_quantized_scalar<op>,
1074 &elementwise_arithm_op_quantized_broadcast_loop<op>,
1075 &elementwise_arithm_op_quantized_loop<op>);
1076}
1077
1078template <ArithmeticOperation op>
1079void elementwise_arithm_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1080{
1081 elementwise_op_quantized_signed(in1, in2, out, window, &elementwise_arithm_op_quantized_signed_scalar<op>,
1082 &elementwise_arithm_op_quantized_signed_broadcast_loop<op>,
1083 &elementwise_arithm_op_quantized_singed_loop<op>);
1084}
1085
1086template <ComparisonOperation op>
1087void elementwise_comp_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1088{
1089 elementwise_op_quantized(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
1090 &elementwise_comp_op_quantized_broadcast_loop<op>,
1091 &elementwise_comp_op_quantized_loop<op>);
1092}
1093
1094template <ComparisonOperation op>
1095void elementwise_comp_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1096{
1097 elementwise_comp_quantized_signed(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
1098 &elementwise_comp_op_quantized_signed_broadcast_loop<op>,
1099 &elementwise_comp_op_quantized_signed_loop<op>);
1100}
1101} // namespace cpu
1102} // namespace arm_compute
1103
1104#endif /* SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H */