blob: 4928ae9bddc902170884774eeac8ea2a4b6b483d [file] [log] [blame]
giuros0192fd9432018-12-03 17:30:00 +00001/*
George Wortd88590f2018-12-12 17:39:58 +00002 * Copyright (c) 2018-2019 ARM Limited.
giuros0192fd9432018-12-03 17:30:00 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEElementwiseOperationKernel.h"
25
26#include "arm_compute/core/CPP/Validate.h"
27#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/IAccessWindow.h"
30#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/NEAsymm.h"
32#include "arm_compute/core/NEON/NEFixedPoint.h"
33#include "arm_compute/core/NEON/wrapper/wrapper.h"
34#include "arm_compute/core/TensorInfo.h"
35#include "arm_compute/core/Validate.h"
36
37#include <algorithm>
38#include <arm_neon.h>
39#include <cstdint>
40#include <map>
41#include <string>
42
43namespace arm_compute
44{
45class Coordinates;
46
47namespace
48{
/** Load 16 QASYMM8 values from @p input1_ptr and dequantize them into four
 * float32x4 vectors: result = (value - offset) * scale, lane-wise.
 */
float32x4x4_t load_quantized(const uint8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
{
    const qasymm8x16_t x = vld1q_u8(input1_ptr);

    // Widen the 16 u8 lanes to two u16x8 halves, then to four s32x4 quarters.
    const uint16x8_t lo_u16 = vmovl_u8(vget_low_u8(x));
    const uint16x8_t hi_u16 = vmovl_u8(vget_high_u8(x));

    const int32x4_t q0 = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(lo_u16)));
    const int32x4_t q1 = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(lo_u16)));
    const int32x4_t q2 = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(hi_u16)));
    const int32x4_t q3 = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(hi_u16)));

    // Dequantize each quarter: (q - offset) * scale
    return { {
                 vmulq_f32(vcvtq_f32_s32(vsubq_s32(q0, offset)), scale),
                 vmulq_f32(vcvtq_f32_s32(vsubq_s32(q1, offset)), scale),
                 vmulq_f32(vcvtq_f32_s32(vsubq_s32(q2, offset)), scale),
                 vmulq_f32(vcvtq_f32_s32(vsubq_s32(q3, offset)), scale),
             } };
}
63
/** Load 16 QASYMM8_SIGNED values from @p input1_ptr and dequantize them into
 * four float32x4 vectors: result = (value - offset) * scale, lane-wise.
 */
float32x4x4_t load_quantized_signed(const int8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
{
    const qasymm8x16_signed_t x = vld1q_s8(input1_ptr);

    // Widen the 16 s8 lanes to two s16x8 halves, then to four s32x4 quarters.
    const int16x8_t lo_s16 = vmovl_s8(vget_low_s8(x));
    const int16x8_t hi_s16 = vmovl_s8(vget_high_s8(x));

    const int32x4_t q0 = vmovl_s16(vget_low_s16(lo_s16));
    const int32x4_t q1 = vmovl_s16(vget_high_s16(lo_s16));
    const int32x4_t q2 = vmovl_s16(vget_low_s16(hi_s16));
    const int32x4_t q3 = vmovl_s16(vget_high_s16(hi_s16));

    // Dequantize each quarter: (q - offset) * scale
    return { {
                 vmulq_f32(vcvtq_f32_s32(vsubq_s32(q0, offset)), scale),
                 vmulq_f32(vcvtq_f32_s32(vsubq_s32(q1, offset)), scale),
                 vmulq_f32(vcvtq_f32_s32(vsubq_s32(q2, offset)), scale),
                 vmulq_f32(vcvtq_f32_s32(vsubq_s32(q3, offset)), scale),
             } };
}
78
/** Saturating-narrow four uint32x4 vectors down to 16 u8 lanes and store them at @p output_ptr. */
void store_quantized(uint8_t *output_ptr, const uint32x4x4_t &out)
{
    const uint16x8_t lo_half = vcombine_u16(vqmovn_u32(out.val[0]), vqmovn_u32(out.val[1]));
    const uint16x8_t hi_half = vcombine_u16(vqmovn_u32(out.val[2]), vqmovn_u32(out.val[3]));
    vst1q_u8(output_ptr, vcombine_u8(vqmovn_u16(lo_half), vqmovn_u16(hi_half)));
}
85
/** Saturating-narrow four int32x4 vectors down to 16 unsigned u8 lanes
 * (negative values clamp to 0 via vqmovun) and store them at @p output_ptr.
 */
void store_quantized(uint8_t *output_ptr, const int32x4x4_t &out)
{
    const int16x8_t lo_half = vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1]));
    const int16x8_t hi_half = vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3]));
    vst1q_u8(output_ptr, vcombine_u8(vqmovun_s16(lo_half), vqmovun_s16(hi_half)));
}
92
/** Re-quantize four float32x4 vectors (out = rf * invscale + offset) and store
 * them as 16 QASYMM8 values at @p output_ptr.
 * The caller pre-biases @p offset by +0.5f for round-to-nearest (see elementwise_op_quantized).
 */
void store_quantized(uint8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
{
    int32x4x4_t out;
    for(int i = 0; i < 4; ++i)
    {
        out.val[i] = vcvtq_s32_f32(vmlaq_f32(offset, rf.val[i], invscale));
    }
    store_quantized(output_ptr, out);
}
106
/** Saturating-narrow four int32x4 vectors down to 16 s8 lanes and store them at @p output_ptr. */
void store_quantized_signed(int8_t *output_ptr, const int32x4x4_t &out)
{
    const int16x8_t lo_half = vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1]));
    const int16x8_t hi_half = vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3]));
    vst1q_s8(output_ptr, vcombine_s8(vqmovn_s16(lo_half), vqmovn_s16(hi_half)));
}
113
/** Re-quantize four float32x4 vectors (out = rf * invscale + offset) and store
 * them as 16 QASYMM8_SIGNED values at @p output_ptr.
 * The caller pre-biases @p offset by +0.5f for round-to-nearest (see elementwise_op_quantized).
 */
void store_quantized_signed(int8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
{
    int32x4x4_t out;
    for(int i = 0; i < 4; ++i)
    {
        out.val[i] = vcvtq_s32_f32(vmlaq_f32(offset, rf.val[i], invscale));
    }
    store_quantized_signed(output_ptr, out);
}
127
/** Dequantize a single QASYMM8 value and broadcast it across four float32x4 vectors.
 *
 * All 16 lanes hold the same value, so the dequantization
 * (value - offset) * scale is computed once in scalar code and duplicated;
 * the s32->f32 conversion and f32 multiply match the lane-wise NEON ops bit-for-bit.
 */
float32x4x4_t dup_quantized(qasymm8_t broadcast_value, int offset, float scale)
{
    const float       dequantized  = static_cast<float>(static_cast<int32_t>(broadcast_value) - offset) * scale;
    const float32x4_t vdequantized = vdupq_n_f32(dequantized);

    return { { vdequantized, vdequantized, vdequantized, vdequantized } };
}
145
146template <ArithmeticOperation op, typename ScalarType>
George Wortd88590f2018-12-12 17:39:58 +0000147inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const ScalarType &b)
giuros0192fd9432018-12-03 17:30:00 +0000148{
149 auto res = ScalarType(0);
150
151 switch(op)
152 {
153 case ArithmeticOperation::MAX:
154 res = std::max(a, b);
155 break;
156 case ArithmeticOperation::MIN:
157 res = std::min(a, b);
158 break;
159 case ArithmeticOperation::SQUARED_DIFF:
160 {
161 res = (a - b) * (a - b);
162 break;
163 }
giuros01d5134362019-05-14 16:12:53 +0100164 case ArithmeticOperation::PRELU:
165 {
166 res = (a > 0 ? a : a * b);
167 break;
168 }
George Worta1e7e282019-01-15 11:00:29 +0000169 case ArithmeticOperation::DIV:
170 {
171 res = a / b;
172 break;
173 }
Usama Arif81e671e2019-05-13 13:33:14 +0100174 case ArithmeticOperation::POWER:
175 {
176 res = std::pow(a, b);
177 break;
178 }
giuros0192fd9432018-12-03 17:30:00 +0000179 default:
180 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
181 }
182 return res;
183}
184
/** Apply the arithmetic operation @p op on already-dequantized float scalars
 * and re-quantize the result to QASYMM8 with @p qinfo.
 */
template <ArithmeticOperation op>
inline uint8_t elementwise_arithm_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    return quantize_qasymm8(elementwise_arithm_op_scalar<op>(a, b), qinfo);
}
190
/** Apply the arithmetic operation @p op on already-dequantized float scalars
 * and re-quantize the result to QASYMM8_SIGNED with @p qinfo.
 */
template <ArithmeticOperation op>
inline int8_t elementwise_arithm_op_quantized_signed_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    return quantize_qasymm8_signed(elementwise_arithm_op_scalar<op>(a, b), qinfo);
}
196
/** Vector implementation of the arithmetic operations.
 *
 * @tparam op         Arithmetic operation to perform.
 * @tparam VectorType wrapper::traits vector descriptor supplying type, scalar_type and tag_type.
 *
 * Note: DIV and POWER have no generic wrapper implementation here; they are
 * handled only by the explicit float specializations below, and fall into the
 * default (error) branch for every other type.
 */
template <ArithmeticOperation op, typename VectorType>
inline typename VectorType::type elementwise_arithm_op(const typename VectorType::type &a, const typename VectorType::type &b)
{
    using vec_type    = typename VectorType::type;
    using scalar_type = typename VectorType::scalar_type;
    using tag_type    = typename VectorType::tag_type;

    vec_type res = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});

    switch(op)
    {
        case ArithmeticOperation::MAX:
            res = wrapper::vmax(a, b);
            break;
        case ArithmeticOperation::MIN:
            res = wrapper::vmin(a, b);
            break;
        case ArithmeticOperation::SQUARED_DIFF:
        {
            const vec_type tmp = wrapper::vsub(a, b);
            res                = wrapper::vmul(tmp, tmp);
            break;
        }
        case ArithmeticOperation::PRELU:
        {
            const vec_type zero = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
            const vec_type tmp  = wrapper::vmul(a, b);
            const auto     gt   = wrapper::vcgt(a, zero);

            // Lane-wise select: a where a > 0, otherwise a * b
            res = wrapper::vbsl(gt, a, tmp);
            break;
        }

        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }

    return res;
}
236
/** DIV specialization for FP32 NEON vectors (no generic wrapper division exists). */
template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
    return wrapper::vdiv(a, b);
}
242
/** POWER specialization for FP32 NEON vectors (no generic wrapper power exists). */
template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
    return wrapper::vpow(a, b);
}
248
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** DIV specialization for FP16 NEON vectors. */
template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
{
    return wrapper::vdiv(a, b);
}

/** POWER specialization for FP16 NEON vectors. */
template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
{
    return wrapper::vpow(a, b);
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
262
/** Apply the arithmetic operation @p op independently to each of the four
 * float32x4 quarters of @p a and @p b (used by the quantized paths after dequantization).
 */
template <ArithmeticOperation op>
inline float32x4x4_t elementwise_arithm_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
    using neon_vector_float = wrapper::traits::neon_vector<float, 4>;

    float32x4x4_t out{};
    for(int i = 0; i < 4; ++i)
    {
        out.val[i] = elementwise_arithm_op<op, neon_vector_float>(a.val[i], b.val[i]);
    }
    return out;
}
278
/** Apply the arithmetic operation @p op between a vector and a broadcast scalar.
 *
 * @param reorder True when the broadcast value is the *first* operand of the
 *                operation (matters for non-commutative ops such as DIV).
 */
template <ArithmeticOperation op, typename ScalarType, typename VectorType>
inline typename VectorType::type elementwise_arithm_op_broadcast(const typename VectorType::type &a, const ScalarType &broadcast_value, const bool reorder)
{
    using tag_type = typename VectorType::tag_type;
    using vec_type = typename VectorType::type;

    const vec_type broadcast_vector = wrapper::vdup_n(broadcast_value, tag_type{});
    const vec_type lhs              = reorder ? broadcast_vector : a;
    const vec_type rhs              = reorder ? a : broadcast_vector;
    return elementwise_arithm_op<op, VectorType>(lhs, rhs);
}
288
289template <ComparisonOperation op, typename InputScalarType>
290inline uint8_t elementwise_comp_op_scalar(const InputScalarType &a, const InputScalarType &b)
291{
292 bool res = false;
293
294 switch(op)
295 {
296 case ComparisonOperation::Equal:
297 res = (a == b);
298 break;
299 case ComparisonOperation::NotEqual:
300 res = (a != b);
301 break;
302 case ComparisonOperation::Greater:
303 res = (a > b);
304 break;
305 case ComparisonOperation::GreaterEqual:
306 res = (a >= b);
307 break;
308 case ComparisonOperation::Less:
309 res = (a < b);
310 break;
311 case ComparisonOperation::LessEqual:
312 res = (a <= b);
313 break;
314 default:
315 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
316 }
317 return res ? ~static_cast<uint8_t>(0) : static_cast<uint8_t>(0);
318}
319
/** Quantized comparison: operands arrive already dequantized to float and the
 * result is a 0/0xFF mask, so @p qinfo is not needed.
 */
template <ComparisonOperation op>
inline uint8_t elementwise_comp_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    ARM_COMPUTE_UNUSED(qinfo);
    return elementwise_comp_op_scalar<op>(a, b);
}
326
/** Vector implementation of the comparison operations.
 *
 * @tparam op               Comparison operation to perform.
 * @tparam InputVectorType  NEON vector type of the operands.
 * @tparam OutputVectorType Unsigned NEON vector type holding the per-lane mask.
 *
 * @return Per-lane mask: all-ones where the comparison holds, zero otherwise.
 */
template <ComparisonOperation op, typename InputVectorType, typename OutputVectorType>
inline OutputVectorType elementwise_comp_op(const InputVectorType &a, const InputVectorType &b)
{
    OutputVectorType res = { 0, 0, 0, 0 };

    switch(op)
    {
        case ComparisonOperation::Equal:
            res = wrapper::vceq(a, b);
            break;
        case ComparisonOperation::NotEqual:
            res = wrapper::vnot(wrapper::vceq(a, b));
            break;
        case ComparisonOperation::Greater:
            res = wrapper::vcgt(a, b);
            break;
        case ComparisonOperation::GreaterEqual:
            res = wrapper::vcge(a, b);
            break;
        case ComparisonOperation::Less:
            // a < b expressed as b > a (only vcgt/vcge are wrapped)
            res = wrapper::vcgt(b, a);
            break;
        case ComparisonOperation::LessEqual:
            // a <= b expressed as b >= a
            res = wrapper::vcge(b, a);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }

    return res;
}
358
/** Apply the comparison @p op independently to each of the four float32x4
 * quarters of @p a and @p b, yielding four uint32x4 lane masks.
 */
template <ComparisonOperation op>
inline uint32x4x4_t elementwise_comp_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
    uint32x4x4_t out{};
    for(int i = 0; i < 4; ++i)
    {
        out.val[i] = elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[i], b.val[i]);
    }
    return out;
}
373
/** Apply the comparison @p op between a vector and a broadcast scalar.
 *
 * @param reorder True when the broadcast value is the *first* operand
 *                (matters for the ordering comparisons).
 */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType, typename OutputVectorType>
inline OutputVectorType elementwise_comp_op_broadcast(const InputVectorType &a, const InputScalarType &broadcast_value, const bool reorder)
{
    const InputVectorType broadcast_vector = wrapper::vdup_n(broadcast_value, wrapper::traits::vector_128_tag());
    const InputVectorType lhs              = reorder ? broadcast_vector : a;
    const InputVectorType rhs              = reorder ? a : broadcast_vector;
    return elementwise_comp_op<op, InputVectorType, OutputVectorType>(lhs, rhs);
}
380
/** Vectorized inner loop for arithmetic ops: processes full vectors of
 * @p window_step_x elements and returns the index of the first element left
 * over for the caller's scalar tail loop.
 */
template <ArithmeticOperation op, typename ScalarType, typename VectorType>
inline int elementwise_arithm_op_loop(int window_start_x, int window_end_x, int window_step_x,
                                      const ScalarType *input1_ptr, const ScalarType *input2_ptr, ScalarType *output_ptr)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = wrapper::vloadq(input1_ptr + x);
        const auto b = wrapper::vloadq(input2_ptr + x);
        wrapper::vstore(output_ptr + x, elementwise_arithm_op<op, VectorType>(a, b));
    }
    return x;
}
394
/** Vectorized inner loop for arithmetic ops on QASYMM8 data:
 * dequantize both inputs, compute in float, re-quantize the result.
 * Returns the index of the first element left for the caller's scalar tail.
 */
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
                                                const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
                                                int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                float32x4_t voffseto, float32x4_t invvscaleo)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        // Get inputs and compute output
        const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
        const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
        store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
412
/** Vectorized inner loop for arithmetic ops on QASYMM8_SIGNED data:
 * dequantize both inputs, compute in float, re-quantize the result.
 * Returns the index of the first element left for the caller's scalar tail.
 *
 * NOTE(review): "singed" in the name is a typo for "signed"; kept as-is since
 * callers elsewhere reference this identifier.
 */
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_singed_loop(int window_start_x, int window_end_x, int window_step_x,
                                                       const int8_t *input1_ptr, const int8_t *input2_ptr, int8_t *output_ptr,
                                                       int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                       float32x4_t voffseto, float32x4_t invvscaleo)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        // Get inputs and compute output
        const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
        const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
        store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
430
/** Vectorized inner loop for arithmetic ops where one input is a broadcast
 * scalar. Returns the index of the first element left for the scalar tail.
 *
 * @param reorder True when the broadcast value is the first operand of the op.
 */
template <ArithmeticOperation op, typename ScalarType, typename VectorType>
inline int elementwise_arithm_op_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                const ScalarType *non_broadcast_input_ptr, const ScalarType &broadcast_value, ScalarType *output_ptr, const bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = wrapper::vloadq((non_broadcast_input_ptr + x));
        wrapper::vstore(output_ptr + x, elementwise_arithm_op_broadcast<op, ScalarType, VectorType>(a, broadcast_value, reorder));
    }
    return x;
}
443
/** Vectorized inner loop for quantized (QASYMM8) arithmetic ops where one
 * input is a pre-dequantized broadcast vector. Returns the index of the first
 * element left for the scalar tail.
 */
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                          const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                          int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                          float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
/** Vectorized inner loop for quantized (QASYMM8_SIGNED) arithmetic ops where
 * one input is a pre-dequantized broadcast vector. Returns the index of the
 * first element left for the scalar tail.
 */
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                                 const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, int8_t *output_ptr,
                                                                 int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                                 float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
George Wortd88590f2018-12-12 17:39:58 +0000474
/** Vectorized inner loop for comparisons on 16-bit input elements: the
 * uint16x8 lane mask is narrowed to u8 before storing. Returns the index of
 * the first element left for the scalar tail.
 */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_16_loop(int window_start_x, int window_end_x, int window_step_x,
                                       const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a   = wrapper::vloadq(input1_ptr + x);
        const auto b   = wrapper::vloadq(input2_ptr + x);
        const auto res = elementwise_comp_op<op, InputVectorType, uint16x8_t>(a, b);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(res));
    }
    return x;
}
489
/** Vectorized inner loop for comparisons on 32-bit input elements: two 4-wide
 * vectors are compared per iteration and the masks narrowed to u8. A 4-wide
 * tail is handled lane-by-lane. Returns the index of the first element left
 * for the caller's scalar tail loop.
 */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_32_loop(int window_start_x, int window_end_x, int window_step_x,
                                       const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        auto       a   = wrapper::vloadq(input1_ptr + x);
        auto       b   = wrapper::vloadq(input2_ptr + x);
        const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
        a              = wrapper::vloadq(input1_ptr + x + 4);
        b              = wrapper::vloadq(input2_ptr + x + 4);
        const auto res2 = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(res), wrapper::vmovn(res2))));
    }
    if(x <= window_end_x - 4)
    {
        const auto a   = wrapper::vloadq(input1_ptr + x);
        const auto b   = wrapper::vloadq(input2_ptr + x);
        const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
        for(int i = 0; i < 4; i++)
        {
            *(output_ptr + x + i) = wrapper::vgetlane(res, i);
        }
        // Fix: was "x = +4;", which *assigned* 4 to x instead of advancing it;
        // the caller's scalar tail then re-walked (or, for windows starting
        // past 4, wrote before) the already-processed range.
        x += 4;
    }
    return x;
}
518
/** Vectorized inner loop for comparisons on QASYMM8 data: inputs are
 * dequantized and compared in float; the output is a raw 0/0xFF mask, so the
 * output quantization parameters are unused. Returns the index of the first
 * element left for the scalar tail.
 */
template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
                                              const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
                                              int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                              float32x4_t voffseto, float32x4_t invvscaleo)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(af, bf);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}
536
/** Vectorized inner loop for comparisons on 16-bit elements with one broadcast
 * scalar input. Returns the index of the first element left for the scalar tail.
 *
 * @param reorder True when the broadcast value is the first operand of the op.
 */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_broadcast_16_loop(int window_start_x, int window_end_x, int window_step_x,
                                                 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint16x8_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(a));
    }
    return x;
}
549
/** Vectorized inner loop for comparisons on 32-bit elements with one broadcast
 * scalar input: two 4-wide masks per iteration narrowed to u8, plus a 4-wide
 * lane-by-lane tail. Returns the index of the first element left for the
 * caller's scalar tail loop.
 *
 * @param reorder True when the broadcast value is the first operand of the op.
 */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_broadcast_32_loop(int window_start_x, int window_end_x, int window_step_x,
                                                 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x), broadcast_value, reorder);
        const auto b = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x + 4), broadcast_value, reorder);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(a), wrapper::vmovn(b))));
    }
    if(x <= window_end_x - 4)
    {
        const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
        for(int i = 0; i < 4; i++)
        {
            *(output_ptr + x + i) = wrapper::vgetlane(a, i);
        }
        // Fix: was "x = +4;", which *assigned* 4 to x instead of advancing it;
        // the caller's scalar tail then re-walked (or, for windows starting
        // past 4, wrote before) the already-processed range.
        x += 4;
    }
    return x;
}
572
/** Vectorized inner loop for quantized (QASYMM8) comparisons where one input
 * is a pre-dequantized broadcast vector. Output is a raw 0/0xFF mask, so the
 * output quantization parameters are unused. Returns the index of the first
 * element left for the scalar tail.
 */
template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                        const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                        int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                        float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}
589
/** Generic driver for an elementwise binary operation over two tensors.
 *
 * Splits execution into a broadcast path (one input has X-step 0, i.e. a
 * single value per row is applied across the other input) and a same-shape
 * path. In both paths the X dimension is collapsed and walked manually: the
 * supplied vectorized loop handles full vector chunks and returns the index
 * where the scalar tail must take over.
 *
 * @param scalar_func    Scalar fallback applied to the tail elements.
 * @param broadcast_func Vectorized loop used when one input broadcasts along X.
 * @param neon_func      Vectorized loop used when both inputs have full X extent.
 */
template <typename InputScalarType, typename OutputScalarType, typename InputVectorType>
void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                    OutputScalarType (*scalar_func)(const InputScalarType &, const InputScalarType &),
                    int (*broadcast_func)(int, int, int, const InputScalarType *, const InputScalarType &, OutputScalarType *, const bool),
                    int (*neon_func)(int, int, int, const InputScalarType *, const InputScalarType *, OutputScalarType *))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    // Full 128-bit vectors of the output type, capped at 8 elements per step
    const int  window_step_x         = std::min(16 / static_cast<int>(sizeof(OutputScalarType)), 8);
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    if(is_broadcast_across_x)
    {
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            auto                  output_ptr              = reinterpret_cast<OutputScalarType *>(output.ptr());
            const auto            non_broadcast_input_ptr = reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
            const InputScalarType broadcast_value         = *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());

            // Vectorized part; the returned index is where the scalar tail starts.
            // reorder (!is_broadcast_input_2): broadcast value is the first operand.
            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_value, output_ptr, !is_broadcast_input_2);
            for(; x < window_end_x; ++x)
            {
                const auto a      = *(non_broadcast_input_ptr + x);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? broadcast_value : a, !is_broadcast_input_2 ? a : broadcast_value);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            auto       output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
            const auto input1_ptr = reinterpret_cast<const InputScalarType *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const InputScalarType *>(input2.ptr());

            // Vectorized part; the returned index is where the scalar tail starts.
            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr);
            for(; x < window_end_x; ++x)
            {
                const auto a      = *(input1_ptr + x);
                const auto b      = *(input2_ptr + x);
                *(output_ptr + x) = (*scalar_func)(a, b);
            }
        },
        input1, input2, output);
    }
}
666
/** Run a QASYMM8 element-wise operation over @p window.
 *
 * Elements are dequantized to float, combined, then requantized with the
 * output tensor's quantization info. The X dimension is traversed manually:
 * the vectorized loop (@p neon_func, or @p broadcast_func when one input has
 * an X-step of 0, i.e. is broadcast along X) consumes chunks of 16 elements
 * and returns the first index it did not process; @p scalar_func then handles
 * the remaining tail elements one by one.
 *
 * @param in1            First input tensor.
 * @param in2            Second input tensor.
 * @param out            Output tensor (QASYMM8).
 * @param window         Execution window.
 * @param scalar_func    Scalar fallback: combines two dequantized values and requantizes with the output info.
 * @param broadcast_func Vector loop for the X-broadcast case. The trailing bool tells it the broadcast
 *                       tensor is input1, so the (input1, input2) argument order can be preserved.
 * @param neon_func      Vector loop for the non-broadcast case.
 */
void elementwise_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                              uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                              int (*broadcast_func)(int, int, int, const uint8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
                                                    float32x4_t, float32x4_t, const bool),
                              int (*neon_func)(int, int, int, const uint8_t *, const uint8_t *, uint8_t *,
                                               int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                               float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16;
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    // An X-step of 0 means that input is stuck on its first X element, i.e. broadcast along X.
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    // Output quantization info (add 0.5 to round toward the nearest integer - 0.5 rounds away from zero)
    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset + 0.5f);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const uint8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<uint8_t *>(output.ptr());

            // One broadcast value per row; dequantize it once into a 16-lane float vector.
            const uint8_t       broadcast_value  = *reinterpret_cast<const uint8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = dup_quantized(broadcast_value, broadcast_qinfo.offset, broadcast_qinfo.scale);

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            // Scalar tail: swap operands when input1 is the broadcast tensor so the
            // operation still sees (input1, input2) in the original order.
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

            // Vectorized body first, then scalar tail from the returned index.
            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}
774
/** Run a QASYMM8_SIGNED (int8) element-wise operation over @p window.
 *
 * Mirrors elementwise_op_quantized but for signed 8-bit asymmetric
 * quantization: elements are dequantized to float, combined, and requantized
 * with the output's quantization info. The X dimension is walked manually:
 * the vector loop (@p neon_func, or @p broadcast_func when one input is
 * broadcast along X) processes 16-element chunks and returns the first
 * unprocessed index; @p scalar_func finishes the tail.
 *
 * @param in1            First input tensor.
 * @param in2            Second input tensor.
 * @param out            Output tensor (QASYMM8_SIGNED).
 * @param window         Execution window.
 * @param scalar_func    Scalar fallback combining two dequantized values.
 * @param broadcast_func Vector loop for the X-broadcast case; the trailing bool
 *                       flags that input1 is the broadcast tensor.
 * @param neon_func      Vector loop for the non-broadcast case.
 */
void elementwise_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                                     int8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                                     int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, int8_t *, int32x4_t, float32x4_t,
                                                           float32x4_t, float32x4_t, const bool),
                                     int (*neon_func)(int, int, int, const int8_t *, const int8_t *, int8_t *,
                                                      int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                                      float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16;
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    // An X-step of 0 means that input is broadcast along X.
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    // Output quantization info (add 0.5 to round toward the nearest integer - 0.5 rounds away from zero)
    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset + 0.5f);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<int8_t *>(output.ptr());

            // One broadcast value per row; dequantize it once into a 16-lane float vector.
            const int8_t        broadcast_value  = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = dup_quantized(broadcast_value, broadcast_qinfo.offset, broadcast_qinfo.scale);

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            // Scalar tail: swap operands when input1 is the broadcast tensor so the
            // operation still sees (input1, input2) in the original order.
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());

            // Vectorized body first, then scalar tail from the returned index.
            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}
882
George Wortd88590f2018-12-12 17:39:58 +0000883template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
884void elementwise_comp_op_16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
giuros0192fd9432018-12-03 17:30:00 +0000885{
George Wortd88590f2018-12-12 17:39:58 +0000886 elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
887 &elementwise_comp_op_scalar<op, InputScalarType>,
888 &elementwise_comp_op_broadcast_16_loop<op, InputScalarType, InputVectorType>,
889 &elementwise_comp_op_16_loop<op, InputScalarType, InputVectorType>);
890}
891
892template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
893void elementwise_comp_op_32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
894{
895 elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
896 &elementwise_comp_op_scalar<op, InputScalarType>,
897 &elementwise_comp_op_broadcast_32_loop<op, InputScalarType, InputVectorType>,
898 &elementwise_comp_op_32_loop<op, InputScalarType, InputVectorType>);
899}
900
giuros01d5134362019-05-14 16:12:53 +0100901template <ArithmeticOperation op, typename VectorType>
George Wortd88590f2018-12-12 17:39:58 +0000902void elementwise_arithm_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
903{
giuros01d5134362019-05-14 16:12:53 +0100904 using scalar_type = typename VectorType::scalar_type;
905
906 elementwise_op<scalar_type, scalar_type, VectorType>(in1, in2, out, window,
907 &elementwise_arithm_op_scalar<op, scalar_type>,
908 &elementwise_arithm_op_broadcast_loop<op, scalar_type, VectorType>,
909 &elementwise_arithm_op_loop<op, scalar_type, VectorType>);
George Wortd88590f2018-12-12 17:39:58 +0000910}
911
912template <ArithmeticOperation op>
913void elementwise_arithm_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
914{
915 elementwise_op_quantized(in1, in2, out, window, &elementwise_arithm_op_quantized_scalar<op>,
916 &elementwise_arithm_op_quantized_broadcast_loop<op>,
917 &elementwise_arithm_op_quantized_loop<op>);
918}
/** Dispatch the arithmetic operation @p op over QASYMM8_SIGNED (int8) tensors.
 *
 * NOTE(review): "singed" in elementwise_arithm_op_quantized_singed_loop looks
 * like a typo for "signed"; the name must match the loop helper declared
 * earlier in this file, so fixing it requires renaming both sites together.
 */
template <ArithmeticOperation op>
void elementwise_arithm_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op_quantized_signed(in1, in2, out, window, &elementwise_arithm_op_quantized_signed_scalar<op>,
                                    &elementwise_arithm_op_quantized_signed_broadcast_loop<op>,
                                    &elementwise_arithm_op_quantized_singed_loop<op>);
}
George Wortd88590f2018-12-12 17:39:58 +0000926
927template <ComparisonOperation op>
928void elementwise_comp_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
929{
930 elementwise_op_quantized(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
931 &elementwise_comp_op_quantized_broadcast_loop<op>,
932 &elementwise_comp_op_quantized_loop<op>);
933}
934
935std::function<void(const ITensor *, const ITensor *, ITensor *, const Window &)>
936configure_func(const ITensor *input1, const ITensor *input2, ITensor *output,
937 std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function)
938{
939 std::string function_to_call("op_");
940 function_to_call += string_from_data_type(input1->info()->data_type()) + "_";
941 function_to_call += string_from_data_type(input2->info()->data_type()) + "_";
942 function_to_call += string_from_data_type(output->info()->data_type());
943
944 auto it = map_function.find(function_to_call);
945
946 if(it != map_function.end())
947 {
948 auto func = it->second;
949 return [func](const ITensor * input1, const ITensor * input2, ITensor * output, const Window & window)
950 {
951 func(input1, input2, output, window);
952 };
953 }
954 return nullptr;
955}
956
957template <ArithmeticOperation op>
958std::function<void(const ITensor *, const ITensor *, ITensor *, const Window &)>
959configure_arithm_func(const ITensor *input1, const ITensor *input2, ITensor *output)
960{
961 static std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function =
962 {
giuros01d5134362019-05-14 16:12:53 +0100963 { "op_F32_F32_F32", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<float, 4>> },
964 { "op_S16_S16_S16", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<int16_t, 8>> },
965 { "op_S32_S32_S32", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<int32_t, 4>> },
Michalis Spyrou8d4d1b82019-11-28 11:31:23 +0000966 { "op_QASYMM8_QASYMM8_QASYMM8", &elementwise_arithm_op_quantized<op> },
967 { "op_QASYMM8_SIGNED_QASYMM8_SIGNED_QASYMM8_SIGNED", &elementwise_arithm_op_quantized_signed<op> }
George Wortd88590f2018-12-12 17:39:58 +0000968 };
969#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
giuros01d5134362019-05-14 16:12:53 +0100970 map_function["op_F16_F16_F16"] = &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<float16_t, 8>>;
George Wortd88590f2018-12-12 17:39:58 +0000971#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
972
973 return configure_func(input1, input2, output, map_function);
974}
975
976template <ComparisonOperation op>
977std::function<void(const ITensor *input1, const ITensor *input2, ITensor *output, const Window &window)>
978configure_comp_func(const ITensor *input1, const ITensor *input2, ITensor *output)
979{
980 static std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function =
981 {
982 { "op_F32_F32_U8", &elementwise_comp_op_32<op, float, float32x4_t> },
983 { "op_S16_S16_U8", &elementwise_comp_op_16<op, int16_t, int16x8_t> },
984 { "op_S32_S32_U8", &elementwise_comp_op_32<op, int32_t, int32x4_t> },
985 { "op_QASYMM8_QASYMM8_U8", &elementwise_comp_op_quantized<op> }
986 };
987#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
988 map_function["op_F16_F16_U8"] = &elementwise_comp_op_16<op, float16_t, float16x8_t>;
989#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
990
991 return configure_func(input1, input2, output, map_function);
992}
993} // namespace
994
/** Default constructor: all members start unset; configure() must run before run(). */
NEElementwiseOperationKernel::NEElementwiseOperationKernel()
    : _function(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr)
{
}
999
/** Validation shared by all element-wise kernels.
 *
 * Checks: supported data types, matching input types, CPU FP16 support, and
 * broadcast compatibility of the two input shapes. When the output is already
 * configured (total_size() > 0), its shape must equal the broadcast shape.
 *
 * @return An error status if any check fails, Status{} otherwise.
 */
Status NEElementwiseOperationKernel::validate_arguments_common(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &input2);

    // An empty broadcast shape means the two input shapes are not broadcast compatible.
    const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

    // Validate in case of configured output
    if(output.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
                                        "Wrong shape for output");
    }

    return Status{};
}
giuros0192fd9432018-12-03 17:30:00 +00001020
giuros0192fd9432018-12-03 17:30:00 +00001021void NEElementwiseOperationKernel::configure_common(const ITensor *input1, const ITensor *input2, ITensor *output)
1022{
1023 ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
giuros0192fd9432018-12-03 17:30:00 +00001024
1025 // Configure kernel window
1026 const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1->info(), *input2->info());
1027 const TensorShape &out_shape = broadcast_pair.first;
1028 const ValidRegion &valid_region = broadcast_pair.second;
1029
1030 // Auto initialize output if not initialized
1031 auto_init_if_empty(*output->info(), out_shape, 1, input1->info()->data_type());
1032
1033 Window win = calculate_max_window(valid_region);
1034
giuros0192fd9432018-12-03 17:30:00 +00001035 _input1 = input1;
1036 _input2 = input2;
1037 _output = output;
1038
giuros0192fd9432018-12-03 17:30:00 +00001039 INEKernel::configure(win);
1040}
1041
/** Execute the kernel on @p window by forwarding to the function pointer
 *  selected during configure(). Errors out if the kernel is unconfigured,
 *  the window is not a valid sub-window, or no function was selected.
 */
void NEElementwiseOperationKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info, window);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_function == nullptr);
    _function(_input1, _input2, _output, window);
}
1050
/** Arithmetic operators (min, max, squared_diff, prelu) */
1052
/** Validate the tensors, run common configuration, then select the kernel
 *  implementation for the requested arithmetic operation. The operation must
 *  be fixed at compile time of the dispatch table, hence the per-case
 *  template instantiation.
 */
void NEArithmeticOperationKernel::configure(ArithmeticOperation op, const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
    configure_common(input1, input2, output);
    switch(op)
    {
        case ArithmeticOperation::MAX:
            _function = configure_arithm_func<ArithmeticOperation::MAX>(input1, input2, output);
            break;
        case ArithmeticOperation::MIN:
            _function = configure_arithm_func<ArithmeticOperation::MIN>(input1, input2, output);
            break;
        case ArithmeticOperation::SQUARED_DIFF:
            _function = configure_arithm_func<ArithmeticOperation::SQUARED_DIFF>(input1, input2, output);
            break;
        case ArithmeticOperation::PRELU:
            _function = configure_arithm_func<ArithmeticOperation::PRELU>(input1, input2, output);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }
}
1075
George Wortd88590f2018-12-12 17:39:58 +00001076Status NEArithmeticOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
1077{
1078 // Validate in case of configured output
1079 if(output.total_size() > 0)
1080 {
1081 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &output);
1082 }
1083 return validate_arguments_common(input1, input2, output);
1084}
1085
giuros0192fd9432018-12-03 17:30:00 +00001086Status NEArithmeticOperationKernel::validate(ArithmeticOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
1087{
1088 ARM_COMPUTE_UNUSED(op);
1089 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
George Wortd88590f2018-12-12 17:39:58 +00001090 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
giuros0192fd9432018-12-03 17:30:00 +00001091 return Status{};
1092}
1093
George Worta1e7e282019-01-15 11:00:29 +00001094/** The division operator */
1095
/** Validate, run common configuration, and select the DIV kernel. */
void NEDivisionOperationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
    configure_common(input1, input2, output);
    _function = configure_arithm_func<ArithmeticOperation::DIV>(input1, input2, output);
}
1102
/** Division is restricted to floating-point inputs; then apply the arithmetic checks. */
Status NEDivisionOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::F16, DataType::F32);
    return NEArithmeticOperationKernel::validate_arguments(input1, input2, output);
}
1108
/** Static validation entry point for the division kernel. */
Status NEDivisionOperationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    // Null checks must precede the dereferences below.
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}
1115
Usama Arif81e671e2019-05-13 13:33:14 +01001116/** The power operator */
/** Validate, run common configuration, and select the POWER kernel. */
void NEPowerOperationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
    configure_common(input1, input2, output);
    _function = configure_arithm_func<ArithmeticOperation::POWER>(input1, input2, output);
}
1123
/** Power is restricted to floating-point inputs; then apply the arithmetic checks. */
Status NEPowerOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::F16, DataType::F32);
    return NEArithmeticOperationKernel::validate_arguments(input1, input2, output);
}
1129
/** Static validation entry point for the power kernel. */
Status NEPowerOperationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    // Null checks must precede the dereferences below.
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}
1136
George Wortd88590f2018-12-12 17:39:58 +00001137/** Comparison operators (equal, not equal, less than, greater than, less than or equal, greater than or equal) */
1138
/** Validate the tensors, run common configuration, then select the kernel
 *  implementation for the requested comparison. The operation must be fixed
 *  at compile time of the dispatch table, hence the per-case template
 *  instantiation.
 */
void NEComparisonOperationKernel::configure(ComparisonOperation op, const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
    configure_common(input1, input2, output);
    switch(op)
    {
        case ComparisonOperation::Equal:
            _function = configure_comp_func<ComparisonOperation::Equal>(input1, input2, output);
            break;
        case ComparisonOperation::NotEqual:
            _function = configure_comp_func<ComparisonOperation::NotEqual>(input1, input2, output);
            break;
        case ComparisonOperation::Greater:
            _function = configure_comp_func<ComparisonOperation::Greater>(input1, input2, output);
            break;
        case ComparisonOperation::GreaterEqual:
            _function = configure_comp_func<ComparisonOperation::GreaterEqual>(input1, input2, output);
            break;
        case ComparisonOperation::Less:
            _function = configure_comp_func<ComparisonOperation::Less>(input1, input2, output);
            break;
        case ComparisonOperation::LessEqual:
            _function = configure_comp_func<ComparisonOperation::LessEqual>(input1, input2, output);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }
}
1167
/** Comparison-specific validation: a configured output must be U8 (boolean
 *  result per element); then run the common checks.
 */
Status NEComparisonOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    // Validate in case of configured output
    if(output.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8);
    }
    return validate_arguments_common(input1, input2, output);
}
1177
/** Static validation entry point. @p op does not affect the checks performed. */
Status NEComparisonOperationKernel::validate(ComparisonOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_UNUSED(op);
    // Null checks must precede the dereferences below.
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}
1185} // namespace arm_compute