/*
 * Copyright (c) 2018-2020 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEElementwiseOperationKernel.h"

#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/NEON/NEAsymm.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"

#include <arm_neon.h>
#include <map>

namespace arm_compute
{
namespace
{
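/** Load 16 QASYMM8 values and dequantize them into four float32x4 vectors: widen to int32, subtract the quantization offset and multiply by the scale. */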
float32x4x4_t load_quantized(const uint8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
{
    qasymm8x16_t x = vld1q_u8(input1_ptr);
    const float32x4x4_t out =
    {
        {
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
        }
    };
    return out;
}

float32x4x4_t load_quantized_signed(const int8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
{
    qasymm8x16_signed_t x = vld1q_s8(input1_ptr);
    const float32x4x4_t out =
    {
        {
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
        }
    };
    return out;
}

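/** Saturating store helpers: narrow four 32-bit vectors back to 16 quantized 8-bit values, optionally requantizing float results with the output offset and inverse scale first. */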
void store_quantized(uint8_t *output_ptr, const uint32x4x4_t &out)
{
    const uint8x8_t pa = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[0]), vqmovn_u32(out.val[1])));
    const uint8x8_t pb = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[2]), vqmovn_u32(out.val[3])));
    vst1q_u8(output_ptr, vcombine_u8(pa, pb));
}

void store_quantized(uint8_t *output_ptr, const int32x4x4_t &out)
{
    const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
    const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
    vst1q_u8(output_ptr, vcombine_u8(pa, pb));
}

void store_quantized(uint8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
{
    int32x4x4_t out =
    {
        {
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)),
        }
    };
    store_quantized(output_ptr, out);
}

void store_quantized_signed(int8_t *output_ptr, const int32x4x4_t &out)
{
    const int8x8_t pa = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
    const int8x8_t pb = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
    vst1q_s8(output_ptr, vcombine_s8(pa, pb));
}

void store_quantized_signed(int8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
{
    int32x4x4_t out =
    {
        {
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)),
        }
    };
    store_quantized_signed(output_ptr, out);
}

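/** Broadcast a single quantized value and dequantize it into four identical float32x4 vectors. */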
float32x4x4_t dup_quantized(qasymm8_t broadcast_value, int offset, float scale)
{
    const qasymm8x16_t broadcast_value_vec = vdupq_n_u8(broadcast_value);
    const int32x4_t voffset = vdupq_n_s32(offset);
    const float32x4_t vscale = vdupq_n_f32(scale);

    const float32x4x4_t broadcast_vector =
    {
        {
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(broadcast_value_vec))))), voffset)), vscale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(broadcast_value_vec))))), voffset)), vscale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(broadcast_value_vec))))), voffset)), vscale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(broadcast_value_vec))))), voffset)), vscale),
        }
    };
    return broadcast_vector;
}

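/** Scalar implementations of the arithmetic operations; used for the leftover elements at the end of each row. */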
template <ArithmeticOperation op, typename ScalarType>
inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const ScalarType &b)
{
    auto res = ScalarType(0);

    switch(op)
    {
        case ArithmeticOperation::MAX:
            res = std::max(a, b);
            break;
        case ArithmeticOperation::MIN:
            res = std::min(a, b);
            break;
        case ArithmeticOperation::SQUARED_DIFF:
        {
            res = (a - b) * (a - b);
            break;
        }
        case ArithmeticOperation::PRELU:
        {
            res = (a > 0 ? a : a * b);
            break;
        }
        case ArithmeticOperation::DIV:
        {
            res = a / b;
            break;
        }
        case ArithmeticOperation::POWER:
        {
            res = std::pow(a, b);
            break;
        }
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }
    return res;
}

template <ArithmeticOperation op>
inline uint8_t elementwise_arithm_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    return quantize_qasymm8(elementwise_arithm_op_scalar<op>(a, b), qinfo);
}

template <ArithmeticOperation op>
inline int8_t elementwise_arithm_op_quantized_signed_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    return quantize_qasymm8_signed(elementwise_arithm_op_scalar<op>(a, b), qinfo);
}

template <ArithmeticOperation op, typename VectorType>
inline typename VectorType::type elementwise_arithm_op(const typename VectorType::type &a, const typename VectorType::type &b)
{
    using vec_type = typename VectorType::type;
    using scalar_type = typename VectorType::scalar_type;
    using tag_type = typename VectorType::tag_type;

    vec_type res = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});

    switch(op)
    {
        case ArithmeticOperation::MAX:
            res = wrapper::vmax(a, b);
            break;
        case ArithmeticOperation::MIN:
            res = wrapper::vmin(a, b);
            break;
        case ArithmeticOperation::SQUARED_DIFF:
        {
            const vec_type tmp = wrapper::vsub(a, b);
            res = wrapper::vmul(tmp, tmp);
            break;
        }
        case ArithmeticOperation::PRELU:
        {
            const vec_type zero = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
            const vec_type tmp  = wrapper::vmul(a, b);
            const auto     gt   = wrapper::vcgt(a, zero);

            res = wrapper::vbsl(gt, a, tmp);
            break;
        }

        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }

    return res;
}

template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
    return wrapper::vdiv(a, b);
}

template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
    return wrapper::vpow(a, b);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
{
    return wrapper::vdiv(a, b);
}

template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
{
    return wrapper::vpow(a, b);
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

template <ArithmeticOperation op>
inline float32x4x4_t elementwise_arithm_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
    using neon_vector_float = wrapper::traits::neon_vector<float, 4>;
    float32x4x4_t out =
    {
        {
            elementwise_arithm_op<op, neon_vector_float>(a.val[0], b.val[0]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[1], b.val[1]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[2], b.val[2]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[3], b.val[3]),
        }
    };
    return out;
}

template <ArithmeticOperation op, typename ScalarType, typename VectorType>
inline typename VectorType::type elementwise_arithm_op_broadcast(const typename VectorType::type &a, const ScalarType &broadcast_value, const bool reorder)
{
    using tag_type = typename VectorType::tag_type;
    using vec_type = typename VectorType::type;

    vec_type broadcast_vector = wrapper::vdup_n(broadcast_value, tag_type{});
    return elementwise_arithm_op<op, VectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
}

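/** Scalar comparison returning an all-ones byte (255) when the predicate holds and 0 otherwise. */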
template <ComparisonOperation op, typename InputScalarType>
inline uint8_t elementwise_comp_op_scalar(const InputScalarType &a, const InputScalarType &b)
{
    bool res = false;

    switch(op)
    {
        case ComparisonOperation::Equal:
            res = (a == b);
            break;
        case ComparisonOperation::NotEqual:
            res = (a != b);
            break;
        case ComparisonOperation::Greater:
            res = (a > b);
            break;
        case ComparisonOperation::GreaterEqual:
            res = (a >= b);
            break;
        case ComparisonOperation::Less:
            res = (a < b);
            break;
        case ComparisonOperation::LessEqual:
            res = (a <= b);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }
    return res ? ~static_cast<uint8_t>(0) : static_cast<uint8_t>(0);
}

template <ComparisonOperation op>
inline uint8_t elementwise_comp_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    ARM_COMPUTE_UNUSED(qinfo);
    return elementwise_comp_op_scalar<op>(a, b);
}

template <ComparisonOperation op, typename InputVectorType, typename OutputVectorType>
inline OutputVectorType elementwise_comp_op(const InputVectorType &a, const InputVectorType &b)
{
    OutputVectorType res = { 0, 0, 0, 0 };

    switch(op)
    {
        case ComparisonOperation::Equal:
            res = wrapper::vceq(a, b);
            break;
        case ComparisonOperation::NotEqual:
            res = wrapper::vnot(wrapper::vceq(a, b));
            break;
        case ComparisonOperation::Greater:
            res = wrapper::vcgt(a, b);
            break;
        case ComparisonOperation::GreaterEqual:
            res = wrapper::vcge(a, b);
            break;
        case ComparisonOperation::Less:
            res = wrapper::vcgt(b, a);
            break;
        case ComparisonOperation::LessEqual:
            res = wrapper::vcge(b, a);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }

    return res;
}

template <ComparisonOperation op>
inline uint32x4x4_t elementwise_comp_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
    uint32x4x4_t out =
    {
        {
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[0], b.val[0]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[1], b.val[1]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[2], b.val[2]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[3], b.val[3])
        }
    };
    return out;
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType, typename OutputVectorType>
inline OutputVectorType elementwise_comp_op_broadcast(const InputVectorType &a, const InputScalarType &broadcast_value, const bool reorder)
{
    InputVectorType broadcast_vector = wrapper::vdup_n(broadcast_value, wrapper::traits::vector_128_tag());
    return elementwise_comp_op<op, InputVectorType, OutputVectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
}

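/** Vectorized main loops: each helper processes whole vectors from window_start_x and returns the index at which the scalar tail must take over. */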
template <ArithmeticOperation op, typename ScalarType, typename VectorType>
inline int elementwise_arithm_op_loop(int window_start_x, int window_end_x, int window_step_x,
                                      const ScalarType *input1_ptr, const ScalarType *input2_ptr, ScalarType *output_ptr)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = wrapper::vloadq(input1_ptr + x);
        const auto b = wrapper::vloadq(input2_ptr + x);
        wrapper::vstore(output_ptr + x, elementwise_arithm_op<op, VectorType>(a, b));
    }
    return x;
}

template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
                                                const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
                                                int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                float32x4_t voffseto, float32x4_t invvscaleo)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        // Get inputs and compute output
        const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
        const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
        store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}

template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_signed_loop(int window_start_x, int window_end_x, int window_step_x,
                                                       const int8_t *input1_ptr, const int8_t *input2_ptr, int8_t *output_ptr,
                                                       int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                       float32x4_t voffseto, float32x4_t invvscaleo)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        // Get inputs and compute output
        const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
        const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
        store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}

template <ArithmeticOperation op, typename ScalarType, typename VectorType>
inline int elementwise_arithm_op_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                const ScalarType *non_broadcast_input_ptr, const ScalarType &broadcast_value, ScalarType *output_ptr, const bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = wrapper::vloadq((non_broadcast_input_ptr + x));
        wrapper::vstore(output_ptr + x, elementwise_arithm_op_broadcast<op, ScalarType, VectorType>(a, broadcast_value, reorder));
    }
    return x;
}

template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                          const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                          int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                          float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                                 const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, int8_t *output_ptr,
                                                                 int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                                 float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_16_loop(int window_start_x, int window_end_x, int window_step_x,
                                       const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a   = wrapper::vloadq(input1_ptr + x);
        const auto b   = wrapper::vloadq(input2_ptr + x);
        const auto res = elementwise_comp_op<op, InputVectorType, uint16x8_t>(a, b);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(res));
    }
    return x;
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_32_loop(int window_start_x, int window_end_x, int window_step_x,
                                       const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        auto a = wrapper::vloadq(input1_ptr + x);
        auto b = wrapper::vloadq(input2_ptr + x);
        const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
        a = wrapper::vloadq(input1_ptr + x + 4);
        b = wrapper::vloadq(input2_ptr + x + 4);
        const auto res2 = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(res), wrapper::vmovn(res2))));
    }
    if(x <= window_end_x - 4)
    {
        const auto a   = wrapper::vloadq(input1_ptr + x);
        const auto b   = wrapper::vloadq(input2_ptr + x);
        const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
        for(int i = 0; i < 4; i++)
        {
            *(output_ptr + x + i) = wrapper::vgetlane(res, i);
        }
        x += 4;
    }
    return x;
}

template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
                                              const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
                                              int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                              float32x4_t voffseto, float32x4_t invvscaleo)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(af, bf);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}

template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_signed_loop(int window_start_x, int window_end_x, int window_step_x,
                                                     const int8_t *input1_ptr, const int8_t *input2_ptr, uint8_t *output_ptr,
                                                     int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                     float32x4_t voffseto, float32x4_t invvscaleo)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(af, bf);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_broadcast_16_loop(int window_start_x, int window_end_x, int window_step_x,
                                                 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint16x8_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(a));
    }
    return x;
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_broadcast_32_loop(int window_start_x, int window_end_x, int window_step_x,
                                                 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x), broadcast_value, reorder);
        const auto b = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x + 4), broadcast_value, reorder);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(a), wrapper::vmovn(b))));
    }
    if(x <= window_end_x - 4)
    {
        const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
        for(int i = 0; i < 4; i++)
        {
            *(output_ptr + x + i) = wrapper::vgetlane(a, i);
        }
        x += 4;
    }
    return x;
}

template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                        const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                        int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                        float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}

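/** Generic element-wise driver: collapses the X dimension of the execution window, selects the broadcast or the two-tensor path, runs the vector loop and finishes each row with the scalar fallback. */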
template <typename InputScalarType, typename OutputScalarType, typename InputVectorType>
void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                    OutputScalarType (*scalar_func)(const InputScalarType &, const InputScalarType &),
                    int (*broadcast_func)(int, int, int, const InputScalarType *, const InputScalarType &, OutputScalarType *, const bool),
                    int (*neon_func)(int, int, int, const InputScalarType *, const InputScalarType *, OutputScalarType *))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int window_step_x = std::min(16 / static_cast<int>(sizeof(OutputScalarType)), 8);
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    if(is_broadcast_across_x)
    {
        const bool is_broadcast_input_2 = input2_win.x().step() == 0;
        Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
        Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            auto output_ptr                    = reinterpret_cast<OutputScalarType *>(output.ptr());
            const auto non_broadcast_input_ptr = reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
            const InputScalarType broadcast_value = *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_value, output_ptr, !is_broadcast_input_2);
            for(; x < window_end_x; ++x)
            {
                const auto a      = *(non_broadcast_input_ptr + x);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? broadcast_value : a, !is_broadcast_input_2 ? a : broadcast_value);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            auto output_ptr       = reinterpret_cast<OutputScalarType *>(output.ptr());
            const auto input1_ptr = reinterpret_cast<const InputScalarType *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const InputScalarType *>(input2.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr);
            for(; x < window_end_x; ++x)
            {
                const auto a      = *(input1_ptr + x);
                const auto b      = *(input2_ptr + x);
                *(output_ptr + x) = (*scalar_func)(a, b);
            }
        },
        input1, input2, output);
    }
}

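/** QASYMM8 variant of the element-wise driver: inputs are dequantized with their own offset and scale, the operation runs in float and the result is requantized with the output offset and inverse scale. */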
void elementwise_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                              uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                              int (*broadcast_func)(int, int, int, const uint8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
                                                    float32x4_t, float32x4_t, const bool),
                              int (*neon_func)(int, int, int, const uint8_t *, const uint8_t *, uint8_t *,
                                               int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                               float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int window_step_x = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    // Output quantization info (add 0.5 to round toward the nearest integer - 0.5 rounds away from zero)
    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset + 0.5f);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool is_broadcast_input_2 = input2_win.x().step() == 0;
        Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
        Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t voffset_non_broadcast  = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const uint8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<uint8_t *>(output.ptr());

            const uint8_t broadcast_value          = *reinterpret_cast<const uint8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector   = dup_quantized(broadcast_value, broadcast_qinfo.offset, broadcast_qinfo.scale);

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t voffset1  = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1 = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t voffset2  = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2 = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}

void elementwise_comp_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                                       uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                                       int (*neon_func)(int, int, int, const int8_t *, const int8_t *, uint8_t *,
                                                        int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                                        float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int window_step_x = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x = static_cast<int>(window.x().end());
    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t voffset1  = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1 = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t voffset2  = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2 = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}

void elementwise_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                                     int8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                                     int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, int8_t *, int32x4_t, float32x4_t,
                                                           float32x4_t, float32x4_t, const bool),
                                     int (*neon_func)(int, int, int, const int8_t *, const int8_t *, int8_t *,
                                                      int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                                      float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int window_step_x = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool is_broadcast_input_2 = input2_win.x().step() == 0;
        Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
        Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t voffset_non_broadcast  = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<int8_t *>(output.ptr());

            const int8_t broadcast_value         = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = dup_quantized(broadcast_value, broadcast_qinfo.offset, broadcast_qinfo.scale);

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t voffset1  = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1 = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t voffset2  = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2 = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
void elementwise_comp_op_16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
                                                              &elementwise_comp_op_scalar<op, InputScalarType>,
                                                              &elementwise_comp_op_broadcast_16_loop<op, InputScalarType, InputVectorType>,
                                                              &elementwise_comp_op_16_loop<op, InputScalarType, InputVectorType>);
}

template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
void elementwise_comp_op_32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
                                                              &elementwise_comp_op_scalar<op, InputScalarType>,
                                                              &elementwise_comp_op_broadcast_32_loop<op, InputScalarType, InputVectorType>,
                                                              &elementwise_comp_op_32_loop<op, InputScalarType, InputVectorType>);
}

template <ArithmeticOperation op, typename VectorType>
void elementwise_arithm_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    using scalar_type = typename VectorType::scalar_type;

    elementwise_op<scalar_type, scalar_type, VectorType>(in1, in2, out, window,
                                                         &elementwise_arithm_op_scalar<op, scalar_type>,
                                                         &elementwise_arithm_op_broadcast_loop<op, scalar_type, VectorType>,
                                                         &elementwise_arithm_op_loop<op, scalar_type, VectorType>);
}

template <ArithmeticOperation op>
void elementwise_arithm_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op_quantized(in1, in2, out, window, &elementwise_arithm_op_quantized_scalar<op>,
                             &elementwise_arithm_op_quantized_broadcast_loop<op>,
                             &elementwise_arithm_op_quantized_loop<op>);
}
template <ArithmeticOperation op>
void elementwise_arithm_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op_quantized_signed(in1, in2, out, window, &elementwise_arithm_op_quantized_signed_scalar<op>,
                                    &elementwise_arithm_op_quantized_signed_broadcast_loop<op>,
                                    &elementwise_arithm_op_quantized_signed_loop<op>);
}

template <ComparisonOperation op>
void elementwise_comp_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op_quantized(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
                             &elementwise_comp_op_quantized_broadcast_loop<op>,
                             &elementwise_comp_op_quantized_loop<op>);
}

template <ComparisonOperation op>
void elementwise_comp_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_comp_quantized_signed(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>, &elementwise_comp_op_quantized_signed_loop<op>);
}

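/** Look up the kernel implementation in a map keyed by "op_<input1 type>_<input2 type>_<output type>"; returns nullptr when the combination is not supported. */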
std::function<void(const ITensor *, const ITensor *, ITensor *, const Window &)>
configure_func(const ITensor *input1, const ITensor *input2, ITensor *output,
               std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function)
{
    std::string function_to_call("op_");
    function_to_call += string_from_data_type(input1->info()->data_type()) + "_";
    function_to_call += string_from_data_type(input2->info()->data_type()) + "_";
    function_to_call += string_from_data_type(output->info()->data_type());

    auto it = map_function.find(function_to_call);

    if(it != map_function.end())
    {
        auto func = it->second;
        return [func](const ITensor * input1, const ITensor * input2, ITensor * output, const Window & window)
        {
            func(input1, input2, output, window);
        };
    }
    return nullptr;
}

template <ArithmeticOperation op>
std::function<void(const ITensor *, const ITensor *, ITensor *, const Window &)>
configure_arithm_func(const ITensor *input1, const ITensor *input2, ITensor *output)
{
    static std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function =
    {
        { "op_F32_F32_F32", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<float, 4>> },
        { "op_S16_S16_S16", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<int16_t, 8>> },
        { "op_S32_S32_S32", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<int32_t, 4>> },
        { "op_QASYMM8_QASYMM8_QASYMM8", &elementwise_arithm_op_quantized<op> },
        { "op_QASYMM8_SIGNED_QASYMM8_SIGNED_QASYMM8_SIGNED", &elementwise_arithm_op_quantized_signed<op> }
    };
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    map_function["op_F16_F16_F16"] = &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<float16_t, 8>>;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

    return configure_func(input1, input2, output, map_function);
}

template <ComparisonOperation op>
std::function<void(const ITensor *input1, const ITensor *input2, ITensor *output, const Window &window)>
configure_comp_func(const ITensor *input1, const ITensor *input2, ITensor *output)
{
    static std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function =
    {
        { "op_F32_F32_U8", &elementwise_comp_op_32<op, float, float32x4_t> },
        { "op_S16_S16_U8", &elementwise_comp_op_16<op, int16_t, int16x8_t> },
        { "op_S32_S32_U8", &elementwise_comp_op_32<op, int32_t, int32x4_t> },
        { "op_QASYMM8_SIGNED_QASYMM8_SIGNED_U8", &elementwise_comp_op_quantized_signed<op> },
        { "op_QASYMM8_QASYMM8_U8", &elementwise_comp_op_quantized<op> }
    };
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    map_function["op_F16_F16_U8"] = &elementwise_comp_op_16<op, float16_t, float16x8_t>;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

    return configure_func(input1, input2, output, map_function);
}
} // namespace

NEElementwiseOperationKernel::NEElementwiseOperationKernel()
    : _function(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr)
{
}

Status NEElementwiseOperationKernel::validate_arguments_common(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &input2);

    const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

    // Validate in case of configured output
    if(output.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
                                        "Wrong shape for output");
    }

    return Status{};
}

void NEElementwiseOperationKernel::configure_common(const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);

    // Configure kernel window
    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1->info(), *input2->info());
    const TensorShape &out_shape    = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    auto_init_if_empty(*output->info(), out_shape, 1, input1->info()->data_type());

    Window win = calculate_max_window(valid_region);

    _input1 = input1;
    _input2 = input2;
    _output = output;

    INEKernel::configure(win);
}

void NEElementwiseOperationKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info, window);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_function == nullptr);
    _function(_input1, _input2, _output, window);
}

/** Arithmetic operators (min, max, squared_diff) */

void NEArithmeticOperationKernel::configure(ArithmeticOperation op, const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
    configure_common(input1, input2, output);
    switch(op)
    {
        case ArithmeticOperation::MAX:
            _function = configure_arithm_func<ArithmeticOperation::MAX>(input1, input2, output);
            break;
        case ArithmeticOperation::MIN:
            _function = configure_arithm_func<ArithmeticOperation::MIN>(input1, input2, output);
            break;
        case ArithmeticOperation::SQUARED_DIFF:
            _function = configure_arithm_func<ArithmeticOperation::SQUARED_DIFF>(input1, input2, output);
            break;
        case ArithmeticOperation::PRELU:
            _function = configure_arithm_func<ArithmeticOperation::PRELU>(input1, input2, output);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }
}

Status NEArithmeticOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    // Validate in case of configured output
    if(output.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &output);
    }
    return validate_arguments_common(input1, input2, output);
}

Status NEArithmeticOperationKernel::validate(ArithmeticOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_UNUSED(op);
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}

/** The division operator */

void NEDivisionOperationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
    configure_common(input1, input2, output);
    _function = configure_arithm_func<ArithmeticOperation::DIV>(input1, input2, output);
}

Status NEDivisionOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::F16, DataType::F32);
    return NEArithmeticOperationKernel::validate_arguments(input1, input2, output);
}

Status NEDivisionOperationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}

/** The power operator */
void NEPowerOperationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
    configure_common(input1, input2, output);
    _function = configure_arithm_func<ArithmeticOperation::POWER>(input1, input2, output);
}

Status NEPowerOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::F16, DataType::F32);
    return NEArithmeticOperationKernel::validate_arguments(input1, input2, output);
}

Status NEPowerOperationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}

/** Comparison operators (equal, not equal, less than, greater than, less than or equal, greater than or equal) */

void NEComparisonOperationKernel::configure(ComparisonOperation op, const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
    configure_common(input1, input2, output);
    switch(op)
    {
        case ComparisonOperation::Equal:
            _function = configure_comp_func<ComparisonOperation::Equal>(input1, input2, output);
            break;
        case ComparisonOperation::NotEqual:
            _function = configure_comp_func<ComparisonOperation::NotEqual>(input1, input2, output);
            break;
        case ComparisonOperation::Greater:
            _function = configure_comp_func<ComparisonOperation::Greater>(input1, input2, output);
            break;
        case ComparisonOperation::GreaterEqual:
            _function = configure_comp_func<ComparisonOperation::GreaterEqual>(input1, input2, output);
            break;
        case ComparisonOperation::Less:
            _function = configure_comp_func<ComparisonOperation::Less>(input1, input2, output);
            break;
        case ComparisonOperation::LessEqual:
            _function = configure_comp_func<ComparisonOperation::LessEqual>(input1, input2, output);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }
}

Status NEComparisonOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    // Validate in case of configured output
    if(output.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8);
    }
    return validate_arguments_common(input1, input2, output);
}

Status NEComparisonOperationKernel::validate(ComparisonOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_UNUSED(op);
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}
} // namespace arm_compute