blob: 3fd5f39e9f4fa7a00c864288170e37664799bdb8 [file] [log] [blame]
giuros0192fd9432018-12-03 17:30:00 +00001/*
morgolocka3598052019-12-31 12:20:47 +00002 * Copyright (c) 2018-2020 ARM Limited.
giuros0192fd9432018-12-03 17:30:00 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEElementwiseOperationKernel.h"
25
26#include "arm_compute/core/CPP/Validate.h"
giuros0192fd9432018-12-03 17:30:00 +000027#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/IAccessWindow.h"
giuros0192fd9432018-12-03 17:30:00 +000029#include "arm_compute/core/NEON/NEAsymm.h"
30#include "arm_compute/core/NEON/NEFixedPoint.h"
31#include "arm_compute/core/NEON/wrapper/wrapper.h"
giuros0192fd9432018-12-03 17:30:00 +000032
giuros0192fd9432018-12-03 17:30:00 +000033#include <arm_neon.h>
giuros0192fd9432018-12-03 17:30:00 +000034#include <map>
giuros0192fd9432018-12-03 17:30:00 +000035
36namespace arm_compute
37{
giuros0192fd9432018-12-03 17:30:00 +000038namespace
39{
/** Load 16 QASYMM8 (unsigned 8-bit quantized) values and dequantize them to floats.
 *
 * Dequantization per element: out = (int(q) - offset) * scale.
 *
 * @param[in] input1_ptr Pointer to 16 contiguous quantized values.
 * @param[in] offset     Quantization zero-point replicated across int32 lanes.
 * @param[in] scale      Quantization scale replicated across float32 lanes.
 *
 * @return Four float32x4 vectors covering the 16 dequantized values, in order.
 */
float32x4x4_t load_quantized(const uint8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
{
    qasymm8x16_t x = vld1q_u8(input1_ptr);
    // Widen u8 -> u16 -> u32, subtract the zero-point as signed, convert to float, then scale.
    const float32x4x4_t out =
    {
        {
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
        }
    };
    return out;
}
54
/** Load 16 QASYMM8_SIGNED (signed 8-bit quantized) values and dequantize them to floats.
 *
 * Dequantization per element: out = (int(q) - offset) * scale.
 *
 * @param[in] input1_ptr Pointer to 16 contiguous signed quantized values.
 * @param[in] offset     Quantization zero-point replicated across int32 lanes.
 * @param[in] scale      Quantization scale replicated across float32 lanes.
 *
 * @return Four float32x4 vectors covering the 16 dequantized values, in order.
 */
float32x4x4_t load_quantized_signed(const int8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
{
    qasymm8x16_signed_t x = vld1q_s8(input1_ptr);
    // Sign-extend s8 -> s16 -> s32, subtract the zero-point, convert to float, then scale.
    const float32x4x4_t out =
    {
        {
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
        }
    };
    return out;
}
69
/** Saturate-narrow four uint32x4 vectors (u32 -> u16 -> u8) and store 16 bytes. */
void store_quantized(uint8_t *output_ptr, const uint32x4x4_t &out)
{
    const uint8x8_t pa = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[0]), vqmovn_u32(out.val[1])));
    const uint8x8_t pb = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[2]), vqmovn_u32(out.val[3])));
    vst1q_u8(output_ptr, vcombine_u8(pa, pb));
}
76
/** Saturate-narrow four int32x4 vectors to unsigned bytes (vqmovun clamps negatives to 0) and store 16 bytes. */
void store_quantized(uint8_t *output_ptr, const int32x4x4_t &out)
{
    const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
    const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
    vst1q_u8(output_ptr, vcombine_u8(pa, pb));
}
83
/** Requantize four float32x4 vectors and store 16 QASYMM8 values.
 *
 * Per element: q = int(rf * invscale + offset). Callers pre-bias @p offset by +0.5f
 * so the truncating float->int conversion rounds to nearest (see elementwise_op_quantized).
 */
void store_quantized(uint8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
{
    int32x4x4_t out =
    {
        {
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)),
        }
    };
    store_quantized(output_ptr, out);
}
97
/** Saturate-narrow four int32x4 vectors (s32 -> s16 -> s8) and store 16 signed bytes. */
void store_quantized_signed(int8_t *output_ptr, const int32x4x4_t &out)
{
    const int8x8_t pa = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
    const int8x8_t pb = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
    vst1q_s8(output_ptr, vcombine_s8(pa, pb));
}
104
/** Requantize four float32x4 vectors and store 16 QASYMM8_SIGNED values.
 *
 * Per element: q = int(rf * invscale + offset), with @p offset pre-biased by the
 * caller for round-to-nearest (see store_quantized above).
 */
void store_quantized_signed(int8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
{
    int32x4x4_t out =
    {
        {
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)),
            vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)),
        }
    };
    store_quantized_signed(output_ptr, out);
}
118
giuros0192fd9432018-12-03 17:30:00 +0000119template <ArithmeticOperation op, typename ScalarType>
George Wortd88590f2018-12-12 17:39:58 +0000120inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const ScalarType &b)
giuros0192fd9432018-12-03 17:30:00 +0000121{
122 auto res = ScalarType(0);
123
124 switch(op)
125 {
126 case ArithmeticOperation::MAX:
127 res = std::max(a, b);
128 break;
129 case ArithmeticOperation::MIN:
130 res = std::min(a, b);
131 break;
132 case ArithmeticOperation::SQUARED_DIFF:
133 {
134 res = (a - b) * (a - b);
135 break;
136 }
giuros01d5134362019-05-14 16:12:53 +0100137 case ArithmeticOperation::PRELU:
138 {
139 res = (a > 0 ? a : a * b);
140 break;
141 }
George Worta1e7e282019-01-15 11:00:29 +0000142 case ArithmeticOperation::DIV:
143 {
144 res = a / b;
145 break;
146 }
Usama Arif81e671e2019-05-13 13:33:14 +0100147 case ArithmeticOperation::POWER:
148 {
149 res = std::pow(a, b);
150 break;
151 }
giuros0192fd9432018-12-03 17:30:00 +0000152 default:
153 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
154 }
155 return res;
156}
157
George Wortd88590f2018-12-12 17:39:58 +0000158template <ArithmeticOperation op>
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100159inline uint8_t elementwise_arithm_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
George Wortd88590f2018-12-12 17:39:58 +0000160{
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100161 return quantize_qasymm8(elementwise_arithm_op_scalar<op>(a, b), qinfo);
George Wortd88590f2018-12-12 17:39:58 +0000162}
163
Michalis Spyrou8d4d1b82019-11-28 11:31:23 +0000164template <ArithmeticOperation op>
165inline int8_t elementwise_arithm_op_quantized_signed_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
166{
167 return quantize_qasymm8_signed(elementwise_arithm_op_scalar<op>(a, b), qinfo);
168}
169
/** Vectorized binary arithmetic operation on a single NEON vector pair.
 *
 * Handles MAX/MIN/SQUARED_DIFF/PRELU generically through the wrapper layer;
 * DIV and POWER only exist as explicit float specializations below.
 *
 * @tparam op         Arithmetic operation to perform.
 * @tparam VectorType wrapper::traits vector descriptor (provides type/scalar_type/tag_type).
 */
template <ArithmeticOperation op, typename VectorType>
inline typename VectorType::type elementwise_arithm_op(const typename VectorType::type &a, const typename VectorType::type &b)
{
    using vec_type    = typename VectorType::type;
    using scalar_type = typename VectorType::scalar_type;
    using tag_type    = typename VectorType::tag_type;

    vec_type res = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});

    switch(op)
    {
        case ArithmeticOperation::MAX:
            res = wrapper::vmax(a, b);
            break;
        case ArithmeticOperation::MIN:
            res = wrapper::vmin(a, b);
            break;
        case ArithmeticOperation::SQUARED_DIFF:
        {
            const vec_type tmp = wrapper::vsub(a, b);
            res = wrapper::vmul(tmp, tmp);
            break;
        }
        case ArithmeticOperation::PRELU:
        {
            // Select a where a > 0, otherwise a * b (lane-wise bit-select).
            const vec_type zero = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
            const vec_type tmp  = wrapper::vmul(a, b);
            const auto gt       = wrapper::vcgt(a, zero);

            res = wrapper::vbsl(gt, a, tmp);
            break;
        }

        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }

    return res;
}
209
/** DIV specialization for float32x4_t (no generic wrapper path in the switch above). */
template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
    return wrapper::vdiv(a, b);
}
215
/** POWER specialization for float32x4_t (no generic wrapper path in the switch above). */
template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
    return wrapper::vpow(a, b);
}
221
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** DIV specialization for float16x8_t (compiled only on FP16-capable targets). */
template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
{
    return wrapper::vdiv(a, b);
}

/** POWER specialization for float16x8_t (compiled only on FP16-capable targets). */
template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
{
    return wrapper::vpow(a, b);
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
235
/** Apply a binary arithmetic operation lane-group-wise to four float32x4 vector pairs.
 *
 * Used by the quantized paths, which dequantize 16 values into a float32x4x4_t.
 */
template <ArithmeticOperation op>
inline float32x4x4_t elementwise_arithm_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
    using neon_vector_float = wrapper::traits::neon_vector<float, 4>;
    float32x4x4_t out =
    {
        {
            elementwise_arithm_op<op, neon_vector_float>(a.val[0], b.val[0]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[1], b.val[1]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[2], b.val[2]),
            elementwise_arithm_op<op, neon_vector_float>(a.val[3], b.val[3]),
        }
    };
    return out;
}
251
/** Arithmetic operation between a vector and a broadcast scalar.
 *
 * @param[in] reorder True when the broadcast value is the FIRST operand of the
 *                    (possibly non-commutative) operation, false when it is the second.
 */
template <ArithmeticOperation op, typename ScalarType, typename VectorType>
inline typename VectorType::type elementwise_arithm_op_broadcast(const typename VectorType::type &a, const ScalarType &broadcast_value, const bool reorder)
{
    using tag_type = typename VectorType::tag_type;
    using vec_type = typename VectorType::type;

    vec_type broadcast_vector = wrapper::vdup_n(broadcast_value, tag_type{});
    return elementwise_arithm_op<op, VectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
}
261
262template <ComparisonOperation op, typename InputScalarType>
263inline uint8_t elementwise_comp_op_scalar(const InputScalarType &a, const InputScalarType &b)
264{
265 bool res = false;
266
267 switch(op)
268 {
269 case ComparisonOperation::Equal:
270 res = (a == b);
271 break;
272 case ComparisonOperation::NotEqual:
273 res = (a != b);
274 break;
275 case ComparisonOperation::Greater:
276 res = (a > b);
277 break;
278 case ComparisonOperation::GreaterEqual:
279 res = (a >= b);
280 break;
281 case ComparisonOperation::Less:
282 res = (a < b);
283 break;
284 case ComparisonOperation::LessEqual:
285 res = (a <= b);
286 break;
287 default:
288 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
289 }
290 return res ? ~static_cast<uint8_t>(0) : static_cast<uint8_t>(0);
291}
292
293template <ComparisonOperation op>
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100294inline uint8_t elementwise_comp_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
George Wortd88590f2018-12-12 17:39:58 +0000295{
296 ARM_COMPUTE_UNUSED(qinfo);
297 return elementwise_comp_op_scalar<op>(a, b);
298}
299
/** Vectorized binary comparison producing a lane mask (all-ones for true, zero for false).
 *
 * @tparam op               Comparison operation.
 * @tparam InputVectorType  NEON vector type of the operands.
 * @tparam OutputVectorType Unsigned NEON vector type of the resulting mask.
 */
template <ComparisonOperation op, typename InputVectorType, typename OutputVectorType>
inline OutputVectorType elementwise_comp_op(const InputVectorType &a, const InputVectorType &b)
{
    OutputVectorType res = { 0, 0, 0, 0 };

    switch(op)
    {
        case ComparisonOperation::Equal:
            res = wrapper::vceq(a, b);
            break;
        case ComparisonOperation::NotEqual:
            // No direct "not equal": invert the equality mask.
            res = wrapper::vnot(wrapper::vceq(a, b));
            break;
        case ComparisonOperation::Greater:
            res = wrapper::vcgt(a, b);
            break;
        case ComparisonOperation::GreaterEqual:
            res = wrapper::vcge(a, b);
            break;
        case ComparisonOperation::Less:
            // a < b expressed as b > a.
            res = wrapper::vcgt(b, a);
            break;
        case ComparisonOperation::LessEqual:
            // a <= b expressed as b >= a.
            res = wrapper::vcge(b, a);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }

    return res;
}
331
/** Apply a comparison lane-group-wise to four float32x4 vector pairs (quantized paths). */
template <ComparisonOperation op>
inline uint32x4x4_t elementwise_comp_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
    uint32x4x4_t out =
    {
        {
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[0], b.val[0]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[1], b.val[1]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[2], b.val[2]),
            elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[3], b.val[3])
        }
    };
    return out;
}
346
/** Comparison between a vector and a broadcast scalar.
 *
 * @param[in] reorder True when the broadcast value is the FIRST operand, false when it is the second.
 */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType, typename OutputVectorType>
inline OutputVectorType elementwise_comp_op_broadcast(const InputVectorType &a, const InputScalarType &broadcast_value, const bool reorder)
{
    InputVectorType broadcast_vector = wrapper::vdup_n(broadcast_value, wrapper::traits::vector_128_tag());
    return elementwise_comp_op<op, InputVectorType, OutputVectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
}
353
/** Vectorized main loop for an arithmetic operation over one row.
 *
 * Processes window_step_x elements per iteration; the caller finishes the
 * remaining tail elements with the scalar function.
 *
 * @return Index of the first element NOT processed by this loop.
 */
template <ArithmeticOperation op, typename ScalarType, typename VectorType>
inline int elementwise_arithm_op_loop(int window_start_x, int window_end_x, int window_step_x,
                                      const ScalarType *input1_ptr, const ScalarType *input2_ptr, ScalarType *output_ptr)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = wrapper::vloadq(input1_ptr + x);
        const auto b = wrapper::vloadq(input2_ptr + x);
        wrapper::vstore(output_ptr + x, elementwise_arithm_op<op, VectorType>(a, b));
    }
    return x;
}
367
/** Vectorized main loop for an arithmetic operation on QASYMM8 data.
 *
 * Dequantizes 16 values per input, runs the op in float, and requantizes the result.
 *
 * @return Index of the first element NOT processed by this loop.
 */
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
                                                const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
                                                int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                float32x4_t voffseto, float32x4_t invvscaleo)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        // Get inputs and compute output
        const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
        const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
        store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
385
/** Vectorized main loop for an arithmetic operation on QASYMM8_SIGNED data.
 *
 * NOTE(review): "singed" in the name is a typo for "signed"; kept as-is because
 * callers elsewhere in this file reference it by this name.
 *
 * @return Index of the first element NOT processed by this loop.
 */
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_singed_loop(int window_start_x, int window_end_x, int window_step_x,
                                                       const int8_t *input1_ptr, const int8_t *input2_ptr, int8_t *output_ptr,
                                                       int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                       float32x4_t voffseto, float32x4_t invvscaleo)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        // Get inputs and compute output
        const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
        const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
        store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
403
/** Vectorized main loop for an arithmetic operation where one operand is a broadcast scalar.
 *
 * @param[in] reorder True when the broadcast value is the FIRST operand.
 *
 * @return Index of the first element NOT processed by this loop.
 */
template <ArithmeticOperation op, typename ScalarType, typename VectorType>
inline int elementwise_arithm_op_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                const ScalarType *non_broadcast_input_ptr, const ScalarType &broadcast_value, ScalarType *output_ptr, const bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = wrapper::vloadq((non_broadcast_input_ptr + x));
        wrapper::vstore(output_ptr + x, elementwise_arithm_op_broadcast<op, ScalarType, VectorType>(a, broadcast_value, reorder));
    }
    return x;
}
416
/** Vectorized broadcast loop for an arithmetic operation on QASYMM8 data.
 *
 * @param[in] broadcast_vector Pre-dequantized broadcast value replicated across all lanes.
 * @param[in] reorder          True when the broadcast value is the FIRST operand.
 *
 * @return Index of the first element NOT processed by this loop.
 */
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                          const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                          int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                          float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
/** Vectorized broadcast loop for an arithmetic operation on QASYMM8_SIGNED data.
 *
 * @param[in] broadcast_vector Pre-dequantized broadcast value replicated across all lanes.
 * @param[in] reorder          True when the broadcast value is the FIRST operand.
 *
 * @return Index of the first element NOT processed by this loop.
 */
template <ArithmeticOperation op>
inline int elementwise_arithm_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                                 const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, int8_t *output_ptr,
                                                                 int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                                 float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
    }
    return x;
}
George Wortd88590f2018-12-12 17:39:58 +0000447
/** Vectorized comparison loop for 16-bit input elements.
 *
 * Produces a uint16x8 lane mask and narrows it to bytes for the uint8 output.
 *
 * @return Index of the first element NOT processed by this loop.
 */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_16_loop(int window_start_x, int window_end_x, int window_step_x,
                                       const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a   = wrapper::vloadq(input1_ptr + x);
        const auto b   = wrapper::vloadq(input2_ptr + x);
        const auto res = elementwise_comp_op<op, InputVectorType, uint16x8_t>(a, b);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(res));
    }
    return x;
}
462
/** Vectorized comparison loop for 32-bit input elements.
 *
 * The main loop consumes window_step_x elements per iteration as two 4-lane
 * vectors, narrowing the two uint32x4 masks to bytes. A 4-element epilogue
 * handles one extra vector before returning to the caller's scalar tail.
 *
 * @return Index of the first element NOT processed by this loop.
 */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_32_loop(int window_start_x, int window_end_x, int window_step_x,
                                       const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        auto a          = wrapper::vloadq(input1_ptr + x);
        auto b          = wrapper::vloadq(input2_ptr + x);
        const auto res  = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
        a               = wrapper::vloadq(input1_ptr + x + 4);
        b               = wrapper::vloadq(input2_ptr + x + 4);
        const auto res2 = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(res), wrapper::vmovn(res2))));
    }
    if(x <= window_end_x - 4)
    {
        const auto a   = wrapper::vloadq(input1_ptr + x);
        const auto b   = wrapper::vloadq(input2_ptr + x);
        const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
        for(int i = 0; i < 4; i++)
        {
            *(output_ptr + x + i) = wrapper::vgetlane(res, i);
        }
        // Fix: was "x = +4;" which reset the cursor to index 4 instead of
        // advancing past the 4 lanes just written, making the caller's scalar
        // tail reprocess already-computed elements.
        x += 4;
    }
    return x;
}
491
/** Vectorized comparison loop on QASYMM8 data.
 *
 * Output is a byte mask, so the output quantization parameters are unused.
 *
 * @return Index of the first element NOT processed by this loop.
 */
template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
                                              const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
                                              int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                              float32x4_t voffseto, float32x4_t invvscaleo)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
        const uint32x4x4_t rf  = elementwise_comp_op<op>(af, bf);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}
509
/** Vectorized comparison loop on QASYMM8_SIGNED inputs.
 *
 * Inputs are signed but the comparison result is always an unsigned byte mask,
 * hence the uint8_t output and the unsigned store. Output quantization
 * parameters are unused.
 *
 * @return Index of the first element NOT processed by this loop.
 */
template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_signed_loop(int window_start_x, int window_end_x, int window_step_x,
                                                     const int8_t *input1_ptr, const int8_t *input2_ptr, uint8_t *output_ptr,
                                                     int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
                                                     float32x4_t voffseto, float32x4_t invvscaleo)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
        const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
        const uint32x4x4_t rf  = elementwise_comp_op<op>(af, bf);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}
527
/** Vectorized broadcast comparison loop for 16-bit input elements.
 *
 * @param[in] reorder True when the broadcast value is the FIRST operand.
 *
 * @return Index of the first element NOT processed by this loop.
 */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_broadcast_16_loop(int window_start_x, int window_end_x, int window_step_x,
                                                 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint16x8_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(a));
    }
    return x;
}
540
/** Vectorized broadcast comparison loop for 32-bit input elements.
 *
 * The main loop compares two 4-lane vectors per iteration and narrows the
 * masks to bytes; a 4-element epilogue handles one extra vector before the
 * caller's scalar tail.
 *
 * @param[in] reorder True when the broadcast value is the FIRST operand.
 *
 * @return Index of the first element NOT processed by this loop.
 */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
inline int elementwise_comp_op_broadcast_32_loop(int window_start_x, int window_end_x, int window_step_x,
                                                 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
{
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x), broadcast_value, reorder);
        const auto b = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x + 4), broadcast_value, reorder);
        wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(a), wrapper::vmovn(b))));
    }
    if(x <= window_end_x - 4)
    {
        const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
        for(int i = 0; i < 4; i++)
        {
            *(output_ptr + x + i) = wrapper::vgetlane(a, i);
        }
        // Fix: was "x = +4;" which reset the cursor to index 4 instead of
        // advancing past the 4 lanes just written, making the caller's scalar
        // tail reprocess already-computed elements.
        x += 4;
    }
    return x;
}
563
/** Vectorized broadcast comparison loop on QASYMM8 data.
 *
 * @param[in] broadcast_vector Pre-dequantized broadcast value replicated across all lanes.
 * @param[in] reorder          True when the broadcast value is the FIRST operand.
 *
 * @return Index of the first element NOT processed by this loop.
 */
template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                        const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                        int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                        float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const uint32x4x4_t rf  = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}
580
/** Vectorized broadcast comparison loop on QASYMM8_SIGNED inputs.
 *
 * Output is the unsigned byte mask convention, so the unsigned store is used.
 *
 * @param[in] broadcast_vector Pre-dequantized broadcast value replicated across all lanes.
 * @param[in] reorder          True when the broadcast value is the FIRST operand.
 *
 * @return Index of the first element NOT processed by this loop.
 */
template <ComparisonOperation op>
inline int elementwise_comp_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
                                                               const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
                                                               int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
                                                               float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
{
    ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
    int x = window_start_x;
    for(; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
        const uint32x4x4_t rf  = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
        store_quantized(output_ptr + x, rf);
    }
    return x;
}
597
/** Generic driver for a non-quantized elementwise operation over two tensors.
 *
 * Splits the execution window along X: a vectorized function handles the bulk
 * and a scalar function finishes the tail. When one input has X-size 1 it is
 * broadcast across X using the dedicated broadcast path.
 *
 * @param[in]  in1            First input tensor.
 * @param[in]  in2            Second input tensor.
 * @param[out] out            Output tensor.
 * @param[in]  window         Execution window.
 * @param[in]  scalar_func    Scalar fallback applied to tail elements.
 * @param[in]  broadcast_func Vectorized loop for the broadcast-on-X case; returns the first unprocessed index.
 * @param[in]  neon_func      Vectorized loop for the elementwise case; returns the first unprocessed index.
 */
template <typename InputScalarType, typename OutputScalarType, typename InputVectorType>
void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                    OutputScalarType (*scalar_func)(const InputScalarType &, const InputScalarType &),
                    int (*broadcast_func)(int, int, int, const InputScalarType *, const InputScalarType &, OutputScalarType *, const bool),
                    int (*neon_func)(int, int, int, const InputScalarType *, const InputScalarType *, OutputScalarType *))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    // 16 bytes per iteration, capped at 8 elements (the widest per-type step used here).
    const int  window_step_x         = std::min(16 / static_cast<int>(sizeof(OutputScalarType)), 8);
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    if(is_broadcast_across_x)
    {
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            auto                  output_ptr              = reinterpret_cast<OutputScalarType *>(output.ptr());
            const auto            non_broadcast_input_ptr = reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
            const InputScalarType broadcast_value         = *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());

            // Vectorized part first; "!is_broadcast_input_2" tells the loop whether
            // the broadcast value is the first operand (operand order matters for
            // non-commutative ops).
            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_value, output_ptr, !is_broadcast_input_2);
            for(; x < window_end_x; ++x)
            {
                const auto a      = *(non_broadcast_input_ptr + x);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? broadcast_value : a, !is_broadcast_input_2 ? a : broadcast_value);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            auto       output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
            const auto input1_ptr = reinterpret_cast<const InputScalarType *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const InputScalarType *>(input2.ptr());

            // Vectorized bulk, then scalar tail from the returned index.
            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr);
            for(; x < window_end_x; ++x)
            {
                const auto a      = *(input1_ptr + x);
                const auto b      = *(input2_ptr + x);
                *(output_ptr + x) = (*scalar_func)(a, b);
            }
        },
        input1, input2, output);
    }
}
674
/** Run a quantized (QASYMM8) element-wise binary operation over a window.
 *
 * Both inputs are dequantized to float lanes, the operation is applied and the
 * result is re-quantized into @p out. If one input is broadcast along the X
 * axis (its window step on X is 0), a dedicated broadcast path is taken.
 *
 * @param in1            First input tensor (QASYMM8).
 * @param in2            Second input tensor (QASYMM8).
 * @param out            Output tensor (QASYMM8).
 * @param window         Execution window.
 * @param scalar_func    Scalar handler for the tail elements left over after the vector loop.
 * @param broadcast_func Vector loop for the X-broadcast case; returns the index of the first unprocessed element.
 * @param neon_func      Vector loop for the non-broadcast case; returns the index of the first unprocessed element.
 */
void elementwise_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                              uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                              int (*broadcast_func)(int, int, int, const uint8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
                                                    float32x4_t, float32x4_t, const bool),
                              int (*neon_func)(int, int, int, const uint8_t *, const uint8_t *, uint8_t *,
                                               int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                               float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16; // 16 x uint8 elements per vector iteration
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    // Output quantization info (add 0.5 to round toward the nearest integer - 0.5 rounds away from zero)
    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset + 0.5f);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const uint8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<uint8_t *>(output.ptr());

            // The broadcast element is constant along the whole X row: splat and dequantize it once.
            const uint8_t       broadcast_value  = *reinterpret_cast<const uint8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_u8(broadcast_value), broadcast_qinfo);

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            // Scalar tail: preserve the original operand order when input1 is the broadcast side.
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            // Scalar tail for the remaining elements that do not fill a full vector.
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}
782
/** Run a quantized-signed (QASYMM8_SIGNED inputs) element-wise comparison over a window.
 *
 * Inputs are dequantized to float lanes, the comparison is applied and the
 * result is written as U8 into @p out. If one input is broadcast along the X
 * axis (window step 0 on X), a dedicated broadcast path is taken.
 *
 * NOTE(review): unlike elementwise_op_quantized, voffseto carries no +0.5f
 * rounding bias here — presumably because comparison results are not
 * re-quantized values; confirm against the loop implementations.
 *
 * @param in1            First input tensor (QASYMM8_SIGNED).
 * @param in2            Second input tensor (QASYMM8_SIGNED).
 * @param out            Output tensor (U8 comparison result).
 * @param window         Execution window.
 * @param scalar_func    Scalar handler for the tail elements after the vector loop.
 * @param broadcast_func Vector loop for the X-broadcast case; returns the first unprocessed element index.
 * @param neon_func      Vector loop for the non-broadcast case; returns the first unprocessed element index.
 */
void elementwise_comp_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                                       uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                                       int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
                                                             float32x4_t, float32x4_t, const bool),
                                       int (*neon_func)(int, int, int, const int8_t *, const int8_t *, uint8_t *,
                                                        int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                                        float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16; // 16 x int8 elements per vector iteration
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<uint8_t *>(output.ptr());

            // The broadcast element is constant along the whole X row: splat and dequantize it once
            // (signed splat + signed vdequantize overload keeps negative values correct).
            const int8_t        broadcast_value  = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo);

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            // Scalar tail: preserve the original operand order when input1 is the broadcast side.
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            // Scalar tail for the remaining elements that do not fill a full vector.
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}
889
Michalis Spyrou8d4d1b82019-11-28 11:31:23 +0000890void elementwise_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
891 int8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
892 int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, int8_t *, int32x4_t, float32x4_t,
893 float32x4_t, float32x4_t, const bool),
894 int (*neon_func)(int, int, int, const int8_t *, const int8_t *, int8_t *,
895 int32x4_t, int32x4_t, float32x4_t, float32x4_t,
896 float32x4_t, float32x4_t))
897{
898 // Create input windows
899 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
900 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
901
902 // Clear X Dimension on execution window as we handle manually
903 Window win = window;
904 win.set(Window::DimX, Window::Dimension(0, 1, 1));
905
906 const int window_step_x = 16;
907 const auto window_start_x = static_cast<int>(window.x().start());
908 const auto window_end_x = static_cast<int>(window.x().end());
909 const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);
910
911 const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();
912
morgolocka3598052019-12-31 12:20:47 +0000913 const float32x4_t voffseto = vdupq_n_f32(output_qinfo.offset);
Michalis Spyrou8d4d1b82019-11-28 11:31:23 +0000914 const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);
915
916 if(is_broadcast_across_x)
917 {
918 // Select the broadcast input on the X axis
919 const bool is_broadcast_input_2 = input2_win.x().step() == 0;
920 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
921 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
922 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
923 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
924
925 const UniformQuantizationInfo broadcast_qinfo = broadcast_tensor->info()->quantization_info().uniform();
926 const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();
927
928 const int32x4_t voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
929 const float32x4_t vscale_non_broadcast = vdupq_n_f32(non_broadcast_qinfo.scale);
930
931 // Clear X Dimension on execution window as we handle manually
932 non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
933
934 Iterator broadcast_input(broadcast_tensor, broadcast_win);
935 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
936 Iterator output(out, win);
937
938 execute_window_loop(win, [&](const Coordinates &)
939 {
940 const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
941 const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());
942
943 const int8_t broadcast_value = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
Michele Di Giorgio81870c02020-04-30 12:02:20 +0100944 const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_u8(broadcast_value), broadcast_qinfo);
Michalis Spyrou8d4d1b82019-11-28 11:31:23 +0000945
946 int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
947 voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
948 for(; x < window_end_x; ++x)
949 {
950 const float afs = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
951 const float bfs = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
952 *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
953 }
954 },
955 broadcast_input, non_broadcast_input, output);
956 }
957 else
958 {
959 const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
960 const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();
961
962 // Input1 quantization info
963 const int32x4_t voffset1 = vdupq_n_s32(input1_qinfo.offset);
964 const float32x4_t vscale1 = vdupq_n_f32(input1_qinfo.scale);
965
966 // Input2 quantization info
967 const int32x4_t voffset2 = vdupq_n_s32(input2_qinfo.offset);
968 const float32x4_t vscale2 = vdupq_n_f32(input2_qinfo.scale);
969
970 // Clear X Dimension on execution window as we handle manually
971 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
972 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
973
974 Iterator input1(in1, input1_win);
975 Iterator input2(in2, input2_win);
976 Iterator output(out, win);
977
978 execute_window_loop(win, [&](const Coordinates &)
979 {
980 const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
981 const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
982 const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());
983
984 int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
985 vscale1, vscale2, voffseto, invvscaleo);
986 for(; x < window_end_x; ++x)
987 {
988 const float afs = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
989 const float bfs = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
990 *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
991 }
992 },
993 input1, input2, output);
994 }
995}
996
George Wortd88590f2018-12-12 17:39:58 +0000997template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
998void elementwise_comp_op_16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
giuros0192fd9432018-12-03 17:30:00 +0000999{
George Wortd88590f2018-12-12 17:39:58 +00001000 elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
1001 &elementwise_comp_op_scalar<op, InputScalarType>,
1002 &elementwise_comp_op_broadcast_16_loop<op, InputScalarType, InputVectorType>,
1003 &elementwise_comp_op_16_loop<op, InputScalarType, InputVectorType>);
1004}
1005
1006template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
1007void elementwise_comp_op_32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1008{
1009 elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
1010 &elementwise_comp_op_scalar<op, InputScalarType>,
1011 &elementwise_comp_op_broadcast_32_loop<op, InputScalarType, InputVectorType>,
1012 &elementwise_comp_op_32_loop<op, InputScalarType, InputVectorType>);
1013}
1014
giuros01d5134362019-05-14 16:12:53 +01001015template <ArithmeticOperation op, typename VectorType>
George Wortd88590f2018-12-12 17:39:58 +00001016void elementwise_arithm_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1017{
giuros01d5134362019-05-14 16:12:53 +01001018 using scalar_type = typename VectorType::scalar_type;
1019
1020 elementwise_op<scalar_type, scalar_type, VectorType>(in1, in2, out, window,
1021 &elementwise_arithm_op_scalar<op, scalar_type>,
1022 &elementwise_arithm_op_broadcast_loop<op, scalar_type, VectorType>,
1023 &elementwise_arithm_op_loop<op, scalar_type, VectorType>);
George Wortd88590f2018-12-12 17:39:58 +00001024}
1025
1026template <ArithmeticOperation op>
1027void elementwise_arithm_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1028{
1029 elementwise_op_quantized(in1, in2, out, window, &elementwise_arithm_op_quantized_scalar<op>,
1030 &elementwise_arithm_op_quantized_broadcast_loop<op>,
1031 &elementwise_arithm_op_quantized_loop<op>);
1032}
/** Dispatch a quantized-signed (QASYMM8_SIGNED) element-wise arithmetic operation. */
template <ArithmeticOperation op>
void elementwise_arithm_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    // NOTE(review): "singed" below is a typo in the loop helper's own name
    // (defined earlier in this file); renaming must happen at the definition.
    elementwise_op_quantized_signed(in1, in2, out, window, &elementwise_arithm_op_quantized_signed_scalar<op>,
                                    &elementwise_arithm_op_quantized_signed_broadcast_loop<op>,
                                    &elementwise_arithm_op_quantized_singed_loop<op>);
}
George Wortd88590f2018-12-12 17:39:58 +00001040
1041template <ComparisonOperation op>
1042void elementwise_comp_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1043{
1044 elementwise_op_quantized(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
1045 &elementwise_comp_op_quantized_broadcast_loop<op>,
1046 &elementwise_comp_op_quantized_loop<op>);
1047}
1048
morgolock74a16962020-01-15 11:40:49 +00001049template <ComparisonOperation op>
1050void elementwise_comp_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1051{
Michele Di Giorgio81870c02020-04-30 12:02:20 +01001052 elementwise_comp_quantized_signed(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
1053 &elementwise_comp_op_quantized_signed_broadcast_loop<op>,
1054 &elementwise_comp_op_quantized_signed_loop<op>);
morgolock74a16962020-01-15 11:40:49 +00001055}
1056
George Wortd88590f2018-12-12 17:39:58 +00001057std::function<void(const ITensor *, const ITensor *, ITensor *, const Window &)>
1058configure_func(const ITensor *input1, const ITensor *input2, ITensor *output,
1059 std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function)
1060{
1061 std::string function_to_call("op_");
1062 function_to_call += string_from_data_type(input1->info()->data_type()) + "_";
1063 function_to_call += string_from_data_type(input2->info()->data_type()) + "_";
1064 function_to_call += string_from_data_type(output->info()->data_type());
1065
1066 auto it = map_function.find(function_to_call);
1067
1068 if(it != map_function.end())
1069 {
1070 auto func = it->second;
1071 return [func](const ITensor * input1, const ITensor * input2, ITensor * output, const Window & window)
1072 {
1073 func(input1, input2, output, window);
1074 };
1075 }
1076 return nullptr;
1077}
1078
1079template <ArithmeticOperation op>
1080std::function<void(const ITensor *, const ITensor *, ITensor *, const Window &)>
1081configure_arithm_func(const ITensor *input1, const ITensor *input2, ITensor *output)
1082{
1083 static std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function =
1084 {
giuros01d5134362019-05-14 16:12:53 +01001085 { "op_F32_F32_F32", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<float, 4>> },
1086 { "op_S16_S16_S16", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<int16_t, 8>> },
1087 { "op_S32_S32_S32", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<int32_t, 4>> },
Michalis Spyrou8d4d1b82019-11-28 11:31:23 +00001088 { "op_QASYMM8_QASYMM8_QASYMM8", &elementwise_arithm_op_quantized<op> },
1089 { "op_QASYMM8_SIGNED_QASYMM8_SIGNED_QASYMM8_SIGNED", &elementwise_arithm_op_quantized_signed<op> }
George Wortd88590f2018-12-12 17:39:58 +00001090 };
1091#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
giuros01d5134362019-05-14 16:12:53 +01001092 map_function["op_F16_F16_F16"] = &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<float16_t, 8>>;
George Wortd88590f2018-12-12 17:39:58 +00001093#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1094
1095 return configure_func(input1, input2, output, map_function);
1096}
1097
1098template <ComparisonOperation op>
1099std::function<void(const ITensor *input1, const ITensor *input2, ITensor *output, const Window &window)>
1100configure_comp_func(const ITensor *input1, const ITensor *input2, ITensor *output)
1101{
1102 static std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function =
1103 {
1104 { "op_F32_F32_U8", &elementwise_comp_op_32<op, float, float32x4_t> },
1105 { "op_S16_S16_U8", &elementwise_comp_op_16<op, int16_t, int16x8_t> },
1106 { "op_S32_S32_U8", &elementwise_comp_op_32<op, int32_t, int32x4_t> },
morgolock74a16962020-01-15 11:40:49 +00001107 { "op_QASYMM8_SIGNED_QASYMM8_SIGNED_U8", &elementwise_comp_op_quantized_signed<op> },
George Wortd88590f2018-12-12 17:39:58 +00001108 { "op_QASYMM8_QASYMM8_U8", &elementwise_comp_op_quantized<op> }
1109 };
1110#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1111 map_function["op_F16_F16_U8"] = &elementwise_comp_op_16<op, float16_t, float16x8_t>;
1112#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1113
1114 return configure_func(input1, input2, output, map_function);
1115}
1116} // namespace
1117
/** Default constructor: leaves the kernel unconfigured (all members null)
 *  until configure() binds tensors and an implementation. */
NEElementwiseOperationKernel::NEElementwiseOperationKernel()
    : _function(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr)
{
}
1122
/** Shared validation for all element-wise kernels.
 *
 * Checks: supported data types on input1, FP16 availability on this CPU,
 * matching input data types, broadcast-compatible shapes, and (if the output
 * is already configured) a matching output shape.
 *
 * @return An error Status on the first failed check, otherwise Status{}.
 */
Status NEElementwiseOperationKernel::validate_arguments_common(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &input2);

    // The broadcast shape is empty when the two shapes cannot be broadcast together.
    const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

    // Validate in case of configured output
    if(output.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
                                        "Wrong shape for output");
    }

    return Status{};
}
giuros0192fd9432018-12-03 17:30:00 +00001142
/** Shared configuration for all element-wise kernels.
 *
 * Computes the broadcast output shape/valid region, auto-initializes the
 * output tensor if it is still empty (using input1's data type), stores the
 * tensor pointers and configures the kernel window.
 */
void NEElementwiseOperationKernel::configure_common(const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);

    // Configure kernel window
    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1->info(), *input2->info());
    const TensorShape &out_shape    = broadcast_pair.first;
    const ValidRegion &valid_region = broadcast_pair.second;

    // Auto initialize output if not initialized
    auto_init_if_empty(*output->info(), out_shape, 1, input1->info()->data_type());

    Window win = calculate_max_window(valid_region);

    _input1 = input1;
    _input2 = input2;
    _output = output;

    INEKernel::configure(win);
}
1163
/** Execute the kernel over @p window by delegating to the function pointer
 *  selected at configure() time. Errors out if the kernel was never
 *  configured or no implementation matched the tensors' data types. */
void NEElementwiseOperationKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info, window);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_function == nullptr);
    _function(_input1, _input2, _output, window);
}
1172
1173/** Arithmetic operators (min, max, squared_diff) */
1174
1175void NEArithmeticOperationKernel::configure(ArithmeticOperation op, const ITensor *input1, const ITensor *input2, ITensor *output)
1176{
George Wortd88590f2018-12-12 17:39:58 +00001177 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
1178 configure_common(input1, input2, output);
giuros0192fd9432018-12-03 17:30:00 +00001179 switch(op)
1180 {
1181 case ArithmeticOperation::MAX:
George Wortd88590f2018-12-12 17:39:58 +00001182 _function = configure_arithm_func<ArithmeticOperation::MAX>(input1, input2, output);
giuros0192fd9432018-12-03 17:30:00 +00001183 break;
1184 case ArithmeticOperation::MIN:
George Wortd88590f2018-12-12 17:39:58 +00001185 _function = configure_arithm_func<ArithmeticOperation::MIN>(input1, input2, output);
giuros0192fd9432018-12-03 17:30:00 +00001186 break;
1187 case ArithmeticOperation::SQUARED_DIFF:
George Wortd88590f2018-12-12 17:39:58 +00001188 _function = configure_arithm_func<ArithmeticOperation::SQUARED_DIFF>(input1, input2, output);
giuros0192fd9432018-12-03 17:30:00 +00001189 break;
giuros01d5134362019-05-14 16:12:53 +01001190 case ArithmeticOperation::PRELU:
1191 _function = configure_arithm_func<ArithmeticOperation::PRELU>(input1, input2, output);
1192 break;
giuros0192fd9432018-12-03 17:30:00 +00001193 default:
1194 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
1195 }
1196}
1197
George Wortd88590f2018-12-12 17:39:58 +00001198Status NEArithmeticOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
1199{
1200 // Validate in case of configured output
1201 if(output.total_size() > 0)
1202 {
1203 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &output);
1204 }
1205 return validate_arguments_common(input1, input2, output);
1206}
1207
/** Static validation entry point; @p op itself does not affect type checks. */
Status NEArithmeticOperationKernel::validate(ArithmeticOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_UNUSED(op);
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}
1215
George Worta1e7e282019-01-15 11:00:29 +00001216/** The division operator */
1217
/** Configure the division kernel: validate, run common setup, then bind the
 *  DIV arithmetic implementation. */
void NEDivisionOperationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
    configure_common(input1, input2, output);
    _function = configure_arithm_func<ArithmeticOperation::DIV>(input1, input2, output);
}
1224
/** Division restricts inputs to floating-point types; all remaining checks
 *  are delegated to the arithmetic-kernel validation. */
Status NEDivisionOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::F16, DataType::F32);
    return NEArithmeticOperationKernel::validate_arguments(input1, input2, output);
}
1230
/** Static validation entry point for the division kernel. */
Status NEDivisionOperationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}
1237
Usama Arif81e671e2019-05-13 13:33:14 +01001238/** The power operator */
/** Configure the power kernel: validate, run common setup, then bind the
 *  POWER arithmetic implementation. */
void NEPowerOperationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
    configure_common(input1, input2, output);
    _function = configure_arithm_func<ArithmeticOperation::POWER>(input1, input2, output);
}
1245
/** Power restricts inputs to floating-point types; all remaining checks are
 *  delegated to the arithmetic-kernel validation. */
Status NEPowerOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::F16, DataType::F32);
    return NEArithmeticOperationKernel::validate_arguments(input1, input2, output);
}
1251
/** Static function to check if the given info will lead to a valid configuration of the power kernel.
 *
 * @param[in] input1 First input tensor info. Data types supported: F16/F32.
 * @param[in] input2 Second input tensor info.
 * @param[in] output Output tensor info.
 *
 * @return a Status
 */
Status NEPowerOperationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    // Reject null pointers first so the dereferences below are safe.
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}
1258
George Wortd88590f2018-12-12 17:39:58 +00001259/** Comparison operators (equal, not equal, less than, greater than, less than or equal, greater than or equal) */
1260
1261void NEComparisonOperationKernel::configure(ComparisonOperation op, const ITensor *input1, const ITensor *input2, ITensor *output)
giuros0192fd9432018-12-03 17:30:00 +00001262{
George Wortd88590f2018-12-12 17:39:58 +00001263 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info()));
1264 configure_common(input1, input2, output);
1265 switch(op)
1266 {
1267 case ComparisonOperation::Equal:
1268 _function = configure_comp_func<ComparisonOperation::Equal>(input1, input2, output);
1269 break;
1270 case ComparisonOperation::NotEqual:
1271 _function = configure_comp_func<ComparisonOperation::NotEqual>(input1, input2, output);
1272 break;
1273 case ComparisonOperation::Greater:
1274 _function = configure_comp_func<ComparisonOperation::Greater>(input1, input2, output);
1275 break;
1276 case ComparisonOperation::GreaterEqual:
1277 _function = configure_comp_func<ComparisonOperation::GreaterEqual>(input1, input2, output);
1278 break;
1279 case ComparisonOperation::Less:
1280 _function = configure_comp_func<ComparisonOperation::Less>(input1, input2, output);
1281 break;
1282 case ComparisonOperation::LessEqual:
1283 _function = configure_comp_func<ComparisonOperation::LessEqual>(input1, input2, output);
1284 break;
1285 default:
1286 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
1287 }
1288}
1289
/** Validate the tensor infos for a comparison kernel.
 *
 * An already-configured output must be U8 (the comparison result mask type
 * enforced here); all remaining checks are delegated to the common
 * element-wise validation.
 *
 * @param[in] input1 First input tensor info.
 * @param[in] input2 Second input tensor info.
 * @param[in] output Output tensor info.
 *
 * @return a Status
 */
Status NEComparisonOperationKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
{
    // Validate in case of configured output
    if(output.total_size() > 0)
    {
        // Comparison results must be written as U8.
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8);
    }
    return validate_arguments_common(input1, input2, output);
}
1299
/** Static function to check if the given info will lead to a valid configuration of the comparison kernel.
 *
 * @param[in] op     Comparison operation (unused: every operation shares the same validation constraints).
 * @param[in] input1 First input tensor info.
 * @param[in] input2 Second input tensor info.
 * @param[in] output Output tensor info. Data type supported: U8.
 *
 * @return a Status
 */
Status NEComparisonOperationKernel::validate(ComparisonOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
{
    // op does not affect validity; silence the unused-parameter warning.
    ARM_COMPUTE_UNUSED(op);
    // Reject null pointers first so the dereferences below are safe.
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
    return Status{};
}
1307} // namespace arm_compute