blob: 444ee8e0d3752fca7771a174bf9dd3d9716db01f [file] [log] [blame]
giuros0192fd9432018-12-03 17:30:00 +00001/*
morgolocka3598052019-12-31 12:20:47 +00002 * Copyright (c) 2018-2020 ARM Limited.
giuros0192fd9432018-12-03 17:30:00 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEElementwiseOperationKernel.h"
25
26#include "arm_compute/core/CPP/Validate.h"
27#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/IAccessWindow.h"
30#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/NEAsymm.h"
32#include "arm_compute/core/NEON/NEFixedPoint.h"
33#include "arm_compute/core/NEON/wrapper/wrapper.h"
34#include "arm_compute/core/TensorInfo.h"
35#include "arm_compute/core/Validate.h"
36
37#include <algorithm>
38#include <arm_neon.h>
39#include <cstdint>
40#include <map>
41#include <string>
42
43namespace arm_compute
44{
45class Coordinates;
46
47namespace
48{
// Loads 16 QASYMM8 (uint8) values and dequantizes them into four float32x4
// vectors: out = (x - offset) * scale.
float32x4x4_t load_quantized(const uint8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
{
    qasymm8x16_t x = vld1q_u8(input1_ptr);
    const float32x4x4_t out =
    {
        {
            // Each entry: widen u8 -> u16 -> s32, subtract the zero-point, scale to float.
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
        }
    };
    return out;
}
63
// Loads 16 QASYMM8_SIGNED (int8) values and dequantizes them into four
// float32x4 vectors: out = (x - offset) * scale.
float32x4x4_t load_quantized_signed(const int8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
{
    qasymm8x16_signed_t x = vld1q_s8(input1_ptr);
    const float32x4x4_t out =
    {
        {
            // Each entry: widen s8 -> s16 -> s32, subtract the zero-point, scale to float.
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
        }
    };
    return out;
}
78
// Saturating-narrows four uint32x4 values (e.g. comparison masks) down to
// 16 uint8 values and stores them at output_ptr.
void store_quantized(uint8_t *output_ptr, const uint32x4x4_t &out)
{
    const uint8x8_t pa = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[0]), vqmovn_u32(out.val[1])));
    const uint8x8_t pb = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[2]), vqmovn_u32(out.val[3])));
    vst1q_u8(output_ptr, vcombine_u8(pa, pb));
}
85
// Saturating-narrows four int32x4 values to 16 uint8 values (negative inputs
// clamp to 0 via vqmovun) and stores them at output_ptr.
void store_quantized(uint8_t *output_ptr, const int32x4x4_t &out)
{
    const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
    const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
    vst1q_u8(output_ptr, vcombine_u8(pa, pb));
}
92
// Requantizes four float32x4 results to QASYMM8: out = rf * invscale + offset,
// converted to s32 then saturating-stored as 16 uint8 values. The caller passes
// offset pre-biased with +0.5f so the float->int truncation rounds to nearest.
void store_quantized(uint8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
{
    int32x4x4_t out =
    {
        {
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)),
        }
    };
    store_quantized(output_ptr, out);
}
106
// Saturating-narrows four int32x4 values to 16 int8 values and stores them.
void store_quantized_signed(int8_t *output_ptr, const int32x4x4_t &out)
{
    const int8x8_t pa = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
    const int8x8_t pb = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
    vst1q_s8(output_ptr, vcombine_s8(pa, pb));
}
113
// Requantizes four float32x4 results to QASYMM8_SIGNED: out = rf * invscale + offset,
// converted to s32 then saturating-stored as 16 int8 values. As in the unsigned
// variant, the caller pre-biases offset with +0.5f for round-to-nearest.
void store_quantized_signed(int8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
{
    int32x4x4_t out =
    {
        {
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)),
        }
    };
    store_quantized_signed(output_ptr, out);
}
127
// Broadcasts a single QASYMM8 value into a dequantized float32x4x4 vector:
// every lane holds (broadcast_value - offset) * scale.
float32x4x4_t dup_quantized(qasymm8_t broadcast_value, int offset, float scale)
{
    const qasymm8x16_t broadcast_value_vec = vdupq_n_u8(broadcast_value);
    const int32x4_t voffset = vdupq_n_s32(offset);
    const float32x4_t vscale = vdupq_n_f32(scale);

    const float32x4x4_t broadcast_vector =
    {
        {
            // Same widen/subtract/scale sequence as load_quantized, applied to the splatted value.
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(broadcast_value_vec))))), voffset)), vscale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(broadcast_value_vec))))), voffset)), vscale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(broadcast_value_vec))))), voffset)), vscale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(broadcast_value_vec))))), voffset)), vscale),
        }
    };
    return broadcast_vector;
}
145
146template <ArithmeticOperation op, typename ScalarType>
George Wortd88590f2018-12-12 17:39:58 +0000147inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const ScalarType &b)
giuros0192fd9432018-12-03 17:30:00 +0000148{
149 auto res = ScalarType(0);
150
151 switch(op)
152 {
153 case ArithmeticOperation::MAX:
154 res = std::max(a, b);
155 break;
156 case ArithmeticOperation::MIN:
157 res = std::min(a, b);
158 break;
159 case ArithmeticOperation::SQUARED_DIFF:
160 {
161 res = (a - b) * (a - b);
162 break;
163 }
giuros01d5134362019-05-14 16:12:53 +0100164 case ArithmeticOperation::PRELU:
165 {
166 res = (a > 0 ? a : a * b);
167 break;
168 }
George Worta1e7e282019-01-15 11:00:29 +0000169 case ArithmeticOperation::DIV:
170 {
171 res = a / b;
172 break;
173 }
Usama Arif81e671e2019-05-13 13:33:14 +0100174 case ArithmeticOperation::POWER:
175 {
176 res = std::pow(a, b);
177 break;
178 }
giuros0192fd9432018-12-03 17:30:00 +0000179 default:
180 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
181 }
182 return res;
183}
184
// Applies the arithmetic op on already-dequantized float inputs and
// requantizes the result to QASYMM8 with the output quantization info.
template <ArithmeticOperation op>
inline uint8_t elementwise_arithm_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    return quantize_qasymm8(elementwise_arithm_op_scalar<op>(a, b), qinfo);
}
190
// Applies the arithmetic op on already-dequantized float inputs and
// requantizes the result to QASYMM8_SIGNED with the output quantization info.
template <ArithmeticOperation op>
inline int8_t elementwise_arithm_op_quantized_signed_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    return quantize_qasymm8_signed(elementwise_arithm_op_scalar<op>(a, b), qinfo);
}
196
// Vector implementation of the supported arithmetic operations, dispatched at
// compile time on op. VectorType is a wrapper::traits::neon_vector describing
// the NEON vector/scalar/tag triple in use. DIV and POWER are provided only
// through the floating-point specializations below.
template <ArithmeticOperation op, typename VectorType>
inline typename VectorType::type elementwise_arithm_op(const typename VectorType::type &a, const typename VectorType::type &b)
{
    using vec_type    = typename VectorType::type;
    using scalar_type = typename VectorType::scalar_type;
    using tag_type    = typename VectorType::tag_type;

    vec_type res = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});

    switch(op)
    {
        case ArithmeticOperation::MAX:
            res = wrapper::vmax(a, b);
            break;
        case ArithmeticOperation::MIN:
            res = wrapper::vmin(a, b);
            break;
        case ArithmeticOperation::SQUARED_DIFF:
        {
            const vec_type tmp = wrapper::vsub(a, b);
            res                = wrapper::vmul(tmp, tmp);
            break;
        }
        case ArithmeticOperation::PRELU:
        {
            // PRELU: select a where a > 0, otherwise a * b (b carries alpha).
            const vec_type zero = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
            const vec_type tmp  = wrapper::vmul(a, b);
            const auto     gt   = wrapper::vcgt(a, zero);

            res = wrapper::vbsl(gt, a, tmp);
            break;
        }

        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }

    return res;
}
236
// Specialization: DIV on float32x4 maps directly to vector division.
template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
    return wrapper::vdiv(a, b);
}
242
// Specialization: POWER on float32x4 maps directly to the vectorized pow.
template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
    return wrapper::vpow(a, b);
}
248
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// float16x8 specializations mirroring the float32x4 DIV/POWER ones above;
// only compiled when the target supports FP16 vector arithmetic.
template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
{
    return wrapper::vdiv(a, b);
}

template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
{
    return wrapper::vpow(a, b);
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
262
// Applies the arithmetic op on four float32x4 vectors at once (the dequantized
// form of 16 quantized elements), block by block.
template <ArithmeticOperation op>
inline float32x4x4_t elementwise_arithm_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
    using neon_vector_float = wrapper::traits::neon_vector<float, 4>;
    float32x4x4_t out =
    {
        {
            elementwise_arithm_op<op, neon_vector_float>(a.val[0], b.val[0]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[1], b.val[1]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[2], b.val[2]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[3], b.val[3]),
        }
    };
    return out;
}
278
giuros01d5134362019-05-14 16:12:53 +0100279template <ArithmeticOperation op, typename ScalarType, typename VectorType>
280inline typename VectorType::type elementwise_arithm_op_broadcast(const typename VectorType::type &a, const ScalarType &broadcast_value, const bool reorder)
George Wortd88590f2018-12-12 17:39:58 +0000281{
giuros01d5134362019-05-14 16:12:53 +0100282 using tag_type = typename VectorType::tag_type;
283 using vec_type = typename VectorType::type;
284
285 vec_type broadcast_vector = wrapper::vdup_n(broadcast_value, tag_type{});
286 return elementwise_arithm_op<op, VectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
George Wortd88590f2018-12-12 17:39:58 +0000287}
288
289template <ComparisonOperation op, typename InputScalarType>
290inline uint8_t elementwise_comp_op_scalar(const InputScalarType &a, const InputScalarType &b)
291{
292 bool res = false;
293
294 switch(op)
295 {
296 case ComparisonOperation::Equal:
297 res = (a == b);
298 break;
299 case ComparisonOperation::NotEqual:
300 res = (a != b);
301 break;
302 case ComparisonOperation::Greater:
303 res = (a > b);
304 break;
305 case ComparisonOperation::GreaterEqual:
306 res = (a >= b);
307 break;
308 case ComparisonOperation::Less:
309 res = (a < b);
310 break;
311 case ComparisonOperation::LessEqual:
312 res = (a <= b);
313 break;
314 default:
315 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
316 }
317 return res ? ~static_cast<uint8_t>(0) : static_cast<uint8_t>(0);
318}
319
// Comparison on dequantized float inputs; qinfo is unused because the result
// is a 0/0xFF mask rather than a quantized value.
template <ComparisonOperation op>
inline uint8_t elementwise_comp_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    ARM_COMPUTE_UNUSED(qinfo);
    return elementwise_comp_op_scalar<op>(a, b);
}
326
// Vector implementation of the supported comparison operations. Lanes where
// the comparison holds are all-ones, otherwise all-zeros. Less/LessEqual are
// implemented by swapping the operands of vcgt/vcge.
template <ComparisonOperation op, typename InputVectorType, typename OutputVectorType>
inline OutputVectorType elementwise_comp_op(const InputVectorType &a, const InputVectorType &b)
{
    OutputVectorType res = { 0, 0, 0, 0 };

    switch(op)
    {
        case ComparisonOperation::Equal:
            res = wrapper::vceq(a, b);
            break;
        case ComparisonOperation::NotEqual:
            res = wrapper::vnot(wrapper::vceq(a, b));
            break;
        case ComparisonOperation::Greater:
            res = wrapper::vcgt(a, b);
            break;
        case ComparisonOperation::GreaterEqual:
            res = wrapper::vcge(a, b);
            break;
        case ComparisonOperation::Less:
            res = wrapper::vcgt(b, a);
            break;
        case ComparisonOperation::LessEqual:
            res = wrapper::vcge(b, a);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }

    return res;
}
358
// Applies the comparison on four float32x4 vectors at once, producing four
// uint32x4 lane masks.
template <ComparisonOperation op>
inline uint32x4x4_t elementwise_comp_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
    uint32x4x4_t out =
    {
        {
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[0], b.val[0]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[1], b.val[1]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[2], b.val[2]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[3], b.val[3])
        }
    };
    return out;
}
373
374template <ComparisonOperation op, typename InputScalarType, typename InputVectorType, typename OutputVectorType>
375inline OutputVectorType elementwise_comp_op_broadcast(const InputVectorType &a, const InputScalarType &broadcast_value, const bool reorder)
376{
377 InputVectorType broadcast_vector = wrapper::vdup_n(broadcast_value, wrapper::traits::vector_128_tag());
378 return elementwise_comp_op<op, InputVectorType, OutputVectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
379}
380
381template <ArithmeticOperation op, typename ScalarType, typename VectorType>
382inline int elementwise_arithm_op_loop(int window_start_x, int window_end_x, int window_step_x,
383 const ScalarType *input1_ptr, const ScalarType *input2_ptr, ScalarType *output_ptr)
384{
385 int x = window_start_x;
386 for(; x <= (window_end_x - window_step_x); x += window_step_x)
387 {
388 const auto a = wrapper::vloadq(input1_ptr + x);
389 const auto b = wrapper::vloadq(input2_ptr + x);
giuros01d5134362019-05-14 16:12:53 +0100390 wrapper::vstore(output_ptr + x, elementwise_arithm_op<op, VectorType>(a, b));
George Wortd88590f2018-12-12 17:39:58 +0000391 }
392 return x;
393}
394
// Vector loop for QASYMM8 arithmetic: dequantizes both inputs, applies the op
// in float, requantizes and stores. Returns the first unprocessed index so the
// caller can handle the scalar tail.
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
                                                const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
                                                int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                float32x4_t voffseto, float32x4_t invvscaleo)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        // Get inputs and compute output
        const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
        const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
        store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
412
// Vector loop for QASYMM8_SIGNED arithmetic: dequantizes both inputs, applies
// the op in float, requantizes and stores. Returns the first unprocessed index.
// NOTE(review): "singed" is a typo for "signed"; the name is kept unchanged
// because callers elsewhere reference it.
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_singed_loop(int window_start_x, int window_end_x, int window_step_x,
                                                       const int8_t *input1_ptr, const int8_t *input2_ptr, int8_t *output_ptr,
                                                       int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                       float32x4_t voffseto, float32x4_t invvscaleo)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        // Get inputs and compute output
        const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
        const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
        store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
430
George Wortd88590f2018-12-12 17:39:58 +0000431template <ArithmeticOperation op, typename ScalarType, typename VectorType>
432inline int elementwise_arithm_op_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
433 const ScalarType *non_broadcast_input_ptr, const ScalarType &broadcast_value, ScalarType *output_ptr, const bool reorder)
434{
435 int x = window_start_x;
436 for(; x <= (window_end_x - window_step_x); x += window_step_x)
437 {
438 const auto a = wrapper::vloadq((non_broadcast_input_ptr + x));
giuros01d5134362019-05-14 16:12:53 +0100439 wrapper::vstore(output_ptr + x, elementwise_arithm_op_broadcast<op, ScalarType, VectorType>(a, broadcast_value, reorder));
George Wortd88590f2018-12-12 17:39:58 +0000440 }
441 return x;
442}
443
// Vector loop for QASYMM8 arithmetic with one broadcast input. The broadcast
// side is pre-dequantized into broadcast_vector (see dup_quantized); reorder
// puts it on the left-hand side. Returns the first unprocessed index.
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                          const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                          int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                          float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
// QASYMM8_SIGNED counterpart of elementwise_arithm_op_quantized_broadcast_loop.
// Returns the first unprocessed index.
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                                 const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, int8_t *output_ptr,
                                                                 int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                                 float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
George Wortd88590f2018-12-12 17:39:58 +0000474
475template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
476inline int elementwise_comp_op_16_loop(int window_start_x, int window_end_x, int window_step_x,
477 const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
478{
479 int x = window_start_x;
480 for(; x <= (window_end_x - window_step_x); x += window_step_x)
481 {
482 const auto a = wrapper::vloadq(input1_ptr + x);
483 const auto b = wrapper::vloadq(input2_ptr + x);
484 const auto res = elementwise_comp_op<op, InputVectorType, uint16x8_t>(a, b);
485 wrapper::vstore(output_ptr + x, wrapper::vmovn(res));
486 }
487 return x;
488}
489
490template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
491inline int elementwise_comp_op_32_loop(int window_start_x, int window_end_x, int window_step_x,
492 const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
493{
494 int x = window_start_x;
495 for(; x <= (window_end_x - window_step_x); x += window_step_x)
496 {
497 auto a = wrapper::vloadq(input1_ptr + x);
498 auto b = wrapper::vloadq(input2_ptr + x);
499 const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
500 a = wrapper::vloadq(input1_ptr + x + 4);
501 b = wrapper::vloadq(input2_ptr + x + 4);
502 const auto res2 = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
503 wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(res), wrapper::vmovn(res2))));
504 }
505 if(x <= window_end_x - 4)
506 {
507 const auto a = wrapper::vloadq(input1_ptr + x);
508 const auto b = wrapper::vloadq(input2_ptr + x);
509 const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
510 for(int i = 0; i < 4; i++)
511 {
512 *(output_ptr + x + i) = wrapper::vgetlane(res, i);
513 }
514 x = +4;
515 }
516 return x;
517}
518
// Vector loop for QASYMM8 comparisons: dequantizes both inputs, compares in
// float and stores the narrowed 0/0xFF masks. The output quantization params
// are unused because the result is a mask. Returns the first unprocessed index.
template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
                                              const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
                                              int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                              float32x4_t voffseto, float32x4_t invvscaleo)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(af, bf);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}
536
// QASYMM8_SIGNED comparison loop: dequantizes signed inputs, compares in float
// and stores the narrowed 0/0xFF masks as uint8 output. Returns the first
// unprocessed index.
template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_signed_loop(int window_start_x, int window_end_x, int window_step_x,
                                                     const int8_t *input1_ptr, const int8_t *input2_ptr, uint8_t *output_ptr,
                                                     int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                     float32x4_t voffseto, float32x4_t invvscaleo)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(af, bf);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}
554
George Wortd88590f2018-12-12 17:39:58 +0000555template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
556inline int elementwise_comp_op_broadcast_16_loop(int window_start_x, int window_end_x, int window_step_x,
557 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
558{
559 int x = window_start_x;
560 for(; x <= (window_end_x - window_step_x); x += window_step_x)
561 {
562 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint16x8_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
563 wrapper::vstore(output_ptr + x, wrapper::vmovn(a));
564 }
565 return x;
566}
567
568template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
569inline int elementwise_comp_op_broadcast_32_loop(int window_start_x, int window_end_x, int window_step_x,
570 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
571{
572 int x = window_start_x;
573 for(; x <= (window_end_x - window_step_x); x += window_step_x)
574 {
575 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x), broadcast_value, reorder);
576 const auto b = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x + 4), broadcast_value, reorder);
577 wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(a), wrapper::vmovn(b))));
578 }
579 if(x <= window_end_x - 4)
580 {
581 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
582 for(int i = 0; i < 4; i++)
583 {
584 *(output_ptr + x + i) = wrapper::vgetlane(a, i);
585 }
586 x = +4;
587 }
588 return x;
589}
590
// Broadcast vector loop for QASYMM8 comparisons. The broadcast side is
// pre-dequantized into broadcast_vector; reorder puts it on the left-hand
// side. Output quantization params are unused (mask output). Returns the
// first unprocessed index.
template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                        const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                        int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                        float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const uint32x4x4_t  rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}
607
// Generic execution driver shared by the element-wise kernels.
//
// in1/in2/out are the input/output tensors; window is the execution window.
//  - scalar_func:    scalar fallback applied to the tail elements.
//  - broadcast_func: vector loop used when one input is broadcast along X.
//  - neon_func:      vector loop used when both inputs span the full X dim.
// The loop functions return the first X index they did not process; the
// remaining elements are finished with scalar_func.
template <typename InputScalarType, typename OutputScalarType, typename InputVectorType>
void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                    OutputScalarType (*scalar_func)(const InputScalarType &, const InputScalarType &),
                    int (*broadcast_func)(int, int, int, const InputScalarType *, const InputScalarType &, OutputScalarType *, const bool),
                    int (*neon_func)(int, int, int, const InputScalarType *, const InputScalarType *, OutputScalarType *))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    // Elements per vector iteration (capped at 8).
    const int  window_step_x         = std::min(16 / static_cast<int>(sizeof(OutputScalarType)), 8);
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    if(is_broadcast_across_x)
    {
        // One input contributes a single value per row, broadcast along X.
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            auto                  output_ptr              = reinterpret_cast<OutputScalarType *>(output.ptr());
            const auto            non_broadcast_input_ptr = reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
            const InputScalarType broadcast_value         = *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());

            // reorder (= !is_broadcast_input_2) keeps the operand order a-op-b
            // for non-commutative operations.
            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_value, output_ptr, !is_broadcast_input_2);
            for(; x < window_end_x; ++x)
            {
                const auto a      = *(non_broadcast_input_ptr + x);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? broadcast_value : a, !is_broadcast_input_2 ? a : broadcast_value);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            auto       output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
            const auto input1_ptr = reinterpret_cast<const InputScalarType *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const InputScalarType *>(input2.ptr());

            // Vector loop first, then the scalar tail.
            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr);
            for(; x < window_end_x; ++x)
            {
                const auto a      = *(input1_ptr + x);
                const auto b      = *(input2_ptr + x);
                *(output_ptr + x) = (*scalar_func)(a, b);
            }
        },
        input1, input2, output);
    }
}
684
// Generic driver for element-wise binary operations on QASYMM8 (uint8) tensors.
// The inputs are dequantized to float, the operation runs in float, and the
// result is requantized on store. Work is delegated to the supplied callbacks:
//  - scalar_func:    per-element handler on dequantized floats (tail loop)
//  - broadcast_func: vector loop used when one input is broadcast along X
//  - neon_func:      vector loop for the element-by-element case
// Each vector loop returns the index of the first element it did NOT process,
// so the scalar tail loop finishes the remainder.
void elementwise_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                              uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                              int (*broadcast_func)(int, int, int, const uint8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
                                                    float32x4_t, float32x4_t, const bool),
                              int (*neon_func)(int, int, int, const uint8_t *, const uint8_t *, uint8_t *,
                                               int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                               float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16; // 16 x uint8_t per vector iteration
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    // A zero X-step on either input window means that input is broadcast along X
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    // Output quantization info (add 0.5 to round toward the nearest integer - 0.5 rounds away from zero)
    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset + 0.5f);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const uint8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<uint8_t *>(output.ptr());

            // The broadcast value is constant along X: dequantize it once per row
            const uint8_t       broadcast_value  = *reinterpret_cast<const uint8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = dup_quantized(broadcast_value, broadcast_qinfo.offset, broadcast_qinfo.scale);

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            // Scalar tail: finish the elements the vector loop left over
            for(; x < window_end_x; ++x)
            {
                const float afs = dequantize_qasymm8(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs = dequantize_qasymm8(broadcast_value, broadcast_qinfo);
                // Preserve operand order when input1 is the broadcast input
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            // Scalar tail: finish the elements the vector loop left over
            for(; x < window_end_x; ++x)
            {
                const float afs = dequantize_qasymm8(*(input1_ptr + x), input1_qinfo);
                const float bfs = dequantize_qasymm8(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}
792
// Driver for element-wise comparisons on QASYMM8_SIGNED (int8) inputs producing
// a U8 result. Inputs are dequantized to float and compared there; neon_func
// handles the vectorized body and returns the first unprocessed index so the
// scalar tail loop can finish the row. Unlike the other drivers in this file
// there is no broadcast_func variant here.
// NOTE(review): voffseto omits the +0.5 rounding bias the QASYMM8 driver adds;
// presumably irrelevant for a comparison mask output — confirm intended.
void elementwise_comp_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                                       uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                                       int (*neon_func)(int, int, int, const int8_t *, const int8_t *, uint8_t *,
                                                        int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                                        float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int                     window_step_x  = 16; // 16 x int8_t per vector iteration
    const auto                    window_start_x = static_cast<int>(window.x().start());
    const auto                    window_end_x   = static_cast<int>(window.x().end());
    const UniformQuantizationInfo output_qinfo   = out->info()->quantization_info().uniform();

    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);
    // Bare scope block (no broadcast path exists, so no if/else split here)
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            // Scalar tail: finish the elements the vector loop left over
            for(; x < window_end_x; ++x)
            {
                const float afs = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
                const float bfs = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}
852
// Generic driver for element-wise binary operations on QASYMM8_SIGNED (int8)
// tensors. Mirrors elementwise_op_quantized: dequantize to float, run the
// operation through the supplied vector loops, requantize on store. Vector
// loops return the first unprocessed index; the scalar tail finishes the row.
// NOTE(review): voffseto omits the +0.5 rounding bias used in the QASYMM8
// driver — confirm the requantization rounding is intended to differ.
void elementwise_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                                     int8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                                     int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, int8_t *, int32x4_t, float32x4_t,
                                                           float32x4_t, float32x4_t, const bool),
                                     int (*neon_func)(int, int, int, const int8_t *, const int8_t *, int8_t *,
                                                      int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                                      float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16; // 16 x int8_t per vector iteration
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    // A zero X-step on either input window means that input is broadcast along X
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<int8_t *>(output.ptr());

            // The broadcast value is constant along X: dequantize it once per row
            const int8_t        broadcast_value  = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = dup_quantized(broadcast_value, broadcast_qinfo.offset, broadcast_qinfo.scale);

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            // Scalar tail: finish the elements the vector loop left over
            for(; x < window_end_x; ++x)
            {
                const float afs = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
                // Preserve operand order when input1 is the broadcast input
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            // Scalar tail: finish the elements the vector loop left over
            for(; x < window_end_x; ++x)
            {
                const float afs = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
                const float bfs = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}
959
George Wortd88590f2018-12-12 17:39:58 +0000960template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
961void elementwise_comp_op_16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
giuros0192fd9432018-12-03 17:30:00 +0000962{
George Wortd88590f2018-12-12 17:39:58 +0000963 elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
964 &elementwise_comp_op_scalar<op, InputScalarType>,
965 &elementwise_comp_op_broadcast_16_loop<op, InputScalarType, InputVectorType>,
966 &elementwise_comp_op_16_loop<op, InputScalarType, InputVectorType>);
967}
968
969template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
970void elementwise_comp_op_32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
971{
972 elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
973 &elementwise_comp_op_scalar<op, InputScalarType>,
974 &elementwise_comp_op_broadcast_32_loop<op, InputScalarType, InputVectorType>,
975 &elementwise_comp_op_32_loop<op, InputScalarType, InputVectorType>);
976}
977
giuros01d5134362019-05-14 16:12:53 +0100978template <ArithmeticOperation op, typename VectorType>
George Wortd88590f2018-12-12 17:39:58 +0000979void elementwise_arithm_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
980{
giuros01d5134362019-05-14 16:12:53 +0100981 using scalar_type = typename VectorType::scalar_type;
982
983 elementwise_op<scalar_type, scalar_type, VectorType>(in1, in2, out, window,
984 &elementwise_arithm_op_scalar<op, scalar_type>,
985 &elementwise_arithm_op_broadcast_loop<op, scalar_type, VectorType>,
986 &elementwise_arithm_op_loop<op, scalar_type, VectorType>);
George Wortd88590f2018-12-12 17:39:58 +0000987}
988
989template <ArithmeticOperation op>
990void elementwise_arithm_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
991{
992 elementwise_op_quantized(in1, in2, out, window, &elementwise_arithm_op_quantized_scalar<op>,
993 &elementwise_arithm_op_quantized_broadcast_loop<op>,
994 &elementwise_arithm_op_quantized_loop<op>);
995}
// Arithmetic op on QASYMM8_SIGNED: forwards the op-specialized loops to the
// signed quantized (int8) driver.
// NOTE(review): "singed" is a typo baked into the loop helper's actual name;
// fixing it requires renaming that definition elsewhere in this file as well.
template <ArithmeticOperation op>
void elementwise_arithm_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op_quantized_signed(in1, in2, out, window, &elementwise_arithm_op_quantized_signed_scalar<op>,
                                    &elementwise_arithm_op_quantized_signed_broadcast_loop<op>,
                                    &elementwise_arithm_op_quantized_singed_loop<op>);
}
George Wortd88590f2018-12-12 17:39:58 +00001003
1004template <ComparisonOperation op>
1005void elementwise_comp_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1006{
1007 elementwise_op_quantized(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
1008 &elementwise_comp_op_quantized_broadcast_loop<op>,
1009 &elementwise_comp_op_quantized_loop<op>);
1010}
1011
morgolock74a16962020-01-15 11:40:49 +00001012template <ComparisonOperation op>
1013void elementwise_comp_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1014{
1015 elementwise_comp_quantized_signed(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>, &elementwise_comp_op_quantized_signed_loop<op>);
1016}
1017
George Wortd88590f2018-12-12 17:39:58 +00001018std::function<void(const ITensor *, const ITensor *, ITensor *, const Window &)>
1019configure_func(const ITensor *input1, const ITensor *input2, ITensor *output,
1020 std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function)
1021{
1022 std::string function_to_call("op_");
1023 function_to_call += string_from_data_type(input1->info()->data_type()) + "_";
1024 function_to_call += string_from_data_type(input2->info()->data_type()) + "_";
1025 function_to_call += string_from_data_type(output->info()->data_type());
1026
1027 auto it = map_function.find(function_to_call);
1028
1029 if(it != map_function.end())
1030 {
1031 auto func = it->second;
1032 return [func](const ITensor * input1, const ITensor * input2, ITensor * output, const Window & window)
1033 {
1034 func(input1, input2, output, window);
1035 };
1036 }
1037 return nullptr;
1038}
1039
1040template <ArithmeticOperation op>
1041std::function<void(const ITensor *, const ITensor *, ITensor *, const Window &)>
1042configure_arithm_func(const ITensor *input1, const ITensor *input2, ITensor *output)
1043{
1044 static std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function =
1045 {
giuros01d5134362019-05-14 16:12:53 +01001046 { "op_F32_F32_F32", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<float, 4>> },
1047 { "op_S16_S16_S16", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<int16_t, 8>> },
1048 { "op_S32_S32_S32", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<int32_t, 4>> },
Michalis Spyrou8d4d1b82019-11-28 11:31:23 +00001049 { "op_QASYMM8_QASYMM8_QASYMM8", &elementwise_arithm_op_quantized<op> },
1050 { "op_QASYMM8_SIGNED_QASYMM8_SIGNED_QASYMM8_SIGNED", &elementwise_arithm_op_quantized_signed<op> }
George Wortd88590f2018-12-12 17:39:58 +00001051 };
1052#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
giuros01d5134362019-05-14 16:12:53 +01001053 map_function["op_F16_F16_F16"] = &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<float16_t, 8>>;
George Wortd88590f2018-12-12 17:39:58 +00001054#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1055
1056 return configure_func(input1, input2, output, map_function);
1057}
1058
1059template <ComparisonOperation op>
1060std::function<void(const ITensor *input1, const ITensor *input2, ITensor *output, const Window &window)>
1061configure_comp_func(const ITensor *input1, const ITensor *input2, ITensor *output)
1062{
1063 static std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function =
1064 {
1065 { "op_F32_F32_U8", &elementwise_comp_op_32<op, float, float32x4_t> },
1066 { "op_S16_S16_U8", &elementwise_comp_op_16<op, int16_t, int16x8_t> },
1067 { "op_S32_S32_U8", &elementwise_comp_op_32<op, int32_t, int32x4_t> },
morgolock74a16962020-01-15 11:40:49 +00001068 { "op_QASYMM8_SIGNED_QASYMM8_SIGNED_U8", &elementwise_comp_op_quantized_signed<op> },
George Wortd88590f2018-12-12 17:39:58 +00001069 { "op_QASYMM8_QASYMM8_U8", &elementwise_comp_op_quantized<op> }
1070 };
1071#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1072 map_function["op_F16_F16_U8"] = &elementwise_comp_op_16<op, float16_t, float16x8_t>;
1073#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1074
1075 return configure_func(input1, input2, output, map_function);
1076}
1077} // namespace
1078
// Default-constructs an unconfigured kernel; configure() must be called before
// run(), which asserts that _function has been selected.
NEElementwiseOperationKernel::NEElementwiseOperationKernel()
    : _function(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr)
{
}
1083
// Validation shared by all element-wise kernels: supported data types, matching
// input types, broadcast-compatible shapes and — when already configured — a
// matching output shape. Returns an error Status on the first failed check.
Status NEElementwiseOperationKernel::validate_arguments_common(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &input2);

    // An empty broadcast shape means the two input shapes are incompatible
    const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

    // Validate in case of configured output
    if(output.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
                                        "Wrong shape for output");
    }

    return Status{};
}
giuros0192fd9432018-12-03 17:30:00 +00001103
// Shared configuration for all element-wise kernels: computes the broadcast
// output shape/valid region, auto-initializes the output tensor if needed,
// stores the tensor pointers and sets the kernel execution window. The
// concrete _function is selected afterwards by the subclass' configure().
void NEElementwiseOperationKernel::configure_common(const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);

    // Configure kernel window
    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1->info(), *input2->info());
    const TensorShape &out_shape    = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    auto_init_if_empty(*output->info(), out_shape, 1, input1->info()->data_type());

    Window win = calculate_max_window(valid_region);

    _input1 = input1;
    _input2 = input2;
    _output = output;

    INEKernel::configure(win);
}
1124
// Executes the function selected at configure() time over the given window.
// Asserts (debug builds) that the kernel was configured, the window is a valid
// sub-window of the configured one, and a function was actually selected.
void NEElementwiseOperationKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info, window); // only used by the assertion macros below
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_function == nullptr);
    _function(_input1, _input2, _output, window);
}
1133
1134/** Arithmetic operators (min, max, squared_diff) */
1135
// Validates the tensors, performs the shared window/output setup and selects
// the type-specialized implementation for the requested arithmetic operation.
// DIV and POWER are not handled here — they have dedicated kernels below.
void NEArithmeticOperationKernel::configure(ArithmeticOperation op, const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
    configure_common(input1, input2, output);
    switch(op)
    {
        case ArithmeticOperation::MAX:
            _function = configure_arithm_func<ArithmeticOperation::MAX>(input1, input2, output);
            break;
        case ArithmeticOperation::MIN:
            _function = configure_arithm_func<ArithmeticOperation::MIN>(input1, input2, output);
            break;
        case ArithmeticOperation::SQUARED_DIFF:
            _function = configure_arithm_func<ArithmeticOperation::SQUARED_DIFF>(input1, input2, output);
            break;
        case ArithmeticOperation::PRELU:
            _function = configure_arithm_func<ArithmeticOperation::PRELU>(input1, input2, output);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }
}
1158
// Arithmetic-specific validation: the output (when configured) must have the
// same data type as the inputs; the rest is delegated to the common checks.
Status NEArithmeticOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    // Validate in case of configured output
    if(output.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &output);
    }
    return validate_arguments_common(input1, input2, output);
}
1168
// Static validation entry point. @p op does not influence the checks (all
// arithmetic ops share the same type rules), hence ARM_COMPUTE_UNUSED.
Status NEArithmeticOperationKernel::validate(ArithmeticOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_UNUSED(op);
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}
1176
George Worta1e7e282019-01-15 11:00:29 +00001177/** The division operator */
1178
// Configures the division kernel: validate, shared setup, then select the
// DIV specialization for the tensors' data types (float types only).
void NEDivisionOperationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
    configure_common(input1, input2, output);
    _function = configure_arithm_func<ArithmeticOperation::DIV>(input1, input2, output);
}
1185
// Division is restricted to floating-point inputs; the remaining checks are
// inherited from the arithmetic kernel validation.
Status NEDivisionOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::F16, DataType::F32);
    return NEArithmeticOperationKernel::validate_arguments(input1, input2, output);
}
1191
// Static validation entry point for the division kernel.
Status NEDivisionOperationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}
1198
Usama Arif81e671e2019-05-13 13:33:14 +01001199/** The power operator */
// Configures the power kernel: validate, shared setup, then select the
// POWER specialization for the tensors' data types (float types only).
void NEPowerOperationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
    configure_common(input1, input2, output);
    _function = configure_arithm_func<ArithmeticOperation::POWER>(input1, input2, output);
}
1206
// Power is restricted to floating-point inputs; the remaining checks are
// inherited from the arithmetic kernel validation.
Status NEPowerOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::F16, DataType::F32);
    return NEArithmeticOperationKernel::validate_arguments(input1, input2, output);
}
1212
// Static validation entry point for the power kernel.
Status NEPowerOperationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}
1219
George Wortd88590f2018-12-12 17:39:58 +00001220/** Comparison operators (equal, not equal, less than, greater than, less than or equal, greater than or equal) */
1221
// Validates the tensors, performs the shared window/output setup and selects
// the type-specialized implementation for the requested comparison operation.
void NEComparisonOperationKernel::configure(ComparisonOperation op, const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
    configure_common(input1, input2, output);
    switch(op)
    {
        case ComparisonOperation::Equal:
            _function = configure_comp_func<ComparisonOperation::Equal>(input1, input2, output);
            break;
        case ComparisonOperation::NotEqual:
            _function = configure_comp_func<ComparisonOperation::NotEqual>(input1, input2, output);
            break;
        case ComparisonOperation::Greater:
            _function = configure_comp_func<ComparisonOperation::Greater>(input1, input2, output);
            break;
        case ComparisonOperation::GreaterEqual:
            _function = configure_comp_func<ComparisonOperation::GreaterEqual>(input1, input2, output);
            break;
        case ComparisonOperation::Less:
            _function = configure_comp_func<ComparisonOperation::Less>(input1, input2, output);
            break;
        case ComparisonOperation::LessEqual:
            _function = configure_comp_func<ComparisonOperation::LessEqual>(input1, input2, output);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }
}
1250
// Comparison-specific validation: the output (when configured) must be U8,
// i.e. a boolean mask; the rest is delegated to the common checks.
Status NEComparisonOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    // Validate in case of configured output
    if(output.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8);
    }
    return validate_arguments_common(input1, input2, output);
}
1260
// Static validation entry point. @p op does not influence the checks (all
// comparison ops share the same type rules), hence ARM_COMPUTE_UNUSED.
Status NEComparisonOperationKernel::validate(ComparisonOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_UNUSED(op);
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}
1268} // namespace arm_compute