blob: 8d74b4027b06bf20acd40e840404be992d099539 [file] [log] [blame]
Sheri Zhang61243902021-01-12 18:25:16 +00001/*
2 * Copyright (c) 2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "src/core/cpu/kernels/CpuAddKernel.h"
25
26#include "arm_compute/core/ITensor.h"
27#include "arm_compute/core/TensorInfo.h"
28#include "arm_compute/core/Validate.h"
29#include "src/core/CPP/Validate.h"
30#include "src/core/common/Registrars.h"
31#include "src/core/cpu/kernels/add/neon/list.h"
32#include "src/core/cpu/kernels/add/sve/list.h"
33#include "src/core/helpers/AutoConfiguration.h"
34#include "src/core/helpers/WindowHelpers.h"
35
36#include <array>
37
38namespace arm_compute
39{
40namespace cpu
41{
42namespace kernels
43{
44namespace
45{
/** Selection data handed to each kernel's selector predicate.
 *
 * Member order matters: instances are brace-initialized positionally
 * (dt1, dt2, dt3, ci) by get_implementation().
 */
struct AddSelectorData
{
    /* Data types for all ITensorInfos:
       dt1 -> src0
       dt2 -> src1
       dt3 -> dst
    */
    DataType dt1;
    DataType dt2;
    DataType dt3;
    // CPU capability info used to gate architecture-specific kernels (e.g. SVE, FP16)
    const CPUInfo &ci;
};
58
// Predicate type: returns true when a kernel supports the given data-type/CPU combination.
using AddSelectorPtr = std::add_pointer<bool(const AddSelectorData &data)>::type;
// Micro-kernel entry point: (src0, src1, dst, convert policy, execution window).
using AddKernelPtr   = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, const ConvertPolicy &, const Window &)>::type;
/** One entry of the micro-kernel registry: a name, a selector predicate and the
 *  function pointer itself (nullptr when the build excluded that variant). */
struct AddKernel
{
    const char          *name;
    const AddSelectorPtr is_selected;
    AddKernelPtr         ukernel;
};
67
/** Ordered micro-kernel registry.
 *
 * get_implementation() walks this array top-to-bottom and returns the FIRST
 * entry whose selector matches, so more specific / more capable variants
 * (SVE2 quantized, then SVE, then NEON) must stay before the generic ones.
 * Do not reorder entries without considering this first-match semantic.
 *
 * NOTE(review): the kernels inside the ARM_COMPUTE_ENABLE_SVE2 guard gate on
 * data.ci.has_sve() only — presumably the SVE2 build guard is considered
 * sufficient; confirm whether a runtime has_sve2() check is intended.
 */
static const AddKernel available_kernels[] =
{
#if defined(ARM_COMPUTE_ENABLE_SVE2)
    {
        "add_qasymm8_sve",
        [](const AddSelectorData & data)
        {
            return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)) && data.ci.has_sve();
        },
        REGISTER_QASYMM8_SVE(arm_compute::cpu::add_qasymm8_sve)
    },
    {
        "add_qasymm8_signed_sve",
        [](const AddSelectorData & data)
        {
            return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)) && data.ci.has_sve();
        },
        REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::add_qasymm8_signed_sve)
    },
    {
        "add_qsymm16_sve",
        [](const AddSelectorData & data)
        {
            return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)) && data.ci.has_sve();
        },
        REGISTER_QSYMM16_SVE(arm_compute::cpu::add_qsymm16_sve)
    },
#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
#if defined(ARM_COMPUTE_ENABLE_SVE)
    {
        "add_same_sve",
        [](const AddSelectorData & data)
        {
            return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)) && data.ci.has_sve();
        },
        REGISTER_FP32_SVE(arm_compute::cpu::add_same_sve<float>)
    },
    {
        "add_same_sve",
        [](const AddSelectorData & data)
        {
            return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)) && data.ci.has_sve();
        },
        REGISTER_FP16_SVE(arm_compute::cpu::add_same_sve<float16_t>)
    },
    {
        "add_same_sve",
        [](const AddSelectorData & data)
        {
            return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)) && data.ci.has_sve();
        },
        REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve<uint8_t>)
    },
    {
        "add_same_sve",
        [](const AddSelectorData & data)
        {
            return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)) && data.ci.has_sve();
        },
        REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve<int16_t>)
    },
    {
        "add_same_sve",
        [](const AddSelectorData & data)
        {
            return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)) && data.ci.has_sve();
        },
        REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve<int32_t>)
    },
    {
        "add_u8_s16_s16_sve",
        [](const AddSelectorData & data)
        {
            return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)) && data.ci.has_sve();
        },
        REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_s16_s16_sve)
    },
    {
        "add_s16_u8_s16_sve",
        [](const AddSelectorData & data)
        {
            return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)) && data.ci.has_sve();
        },
        REGISTER_INTEGER_SVE(arm_compute::cpu::add_s16_u8_s16_sve)
    },
    {
        "add_u8_u8_s16_sve",
        [](const AddSelectorData & data)
        {
            return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)) && data.ci.has_sve();
        },
        REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_u8_s16_sve)
    },
#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */
#if defined(ARM_COMPUTE_ENABLE_NEON)
    {
        "add_same_neon",
        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)); },
        REGISTER_FP32_NEON(arm_compute::cpu::add_same_neon<float>)
    },
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
    {
        "add_same_neon",
        [](const AddSelectorData & data)
        {
            return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)) && data.ci.has_fp16();
        },
        REGISTER_FP16_NEON(arm_compute::cpu::add_same_neon<float16_t>)
    },
#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
    {
        "add_same_neon",
        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)); },
        REGISTER_INTEGER_NEON(arm_compute::cpu::add_same_neon<uint8_t>)
    },
    {
        "add_same_neon",
        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)); },
        REGISTER_INTEGER_NEON(arm_compute::cpu::add_same_neon<int16_t>)
    },
    {
        "add_same_neon",
        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)); },
        REGISTER_INTEGER_NEON(arm_compute::cpu::add_same_neon<int32_t>)
    },
    {
        "add_u8_s16_s16_neon",
        [](const AddSelectorData & data) { return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)); },
        REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_s16_s16_neon)
    },
    {
        "add_s16_u8_s16_neon",
        [](const AddSelectorData & data) { return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)); },
        REGISTER_INTEGER_NEON(arm_compute::cpu::add_s16_u8_s16_neon)
    },
    {
        "add_u8_u8_s16_neon",
        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)); },
        REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_u8_s16_neon)
    },
#endif /* defined(ARM_COMPUTE_ENABLE_NEON) */
#if defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE)
    {
        "add_qasymm8_neon",
        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)); },
        REGISTER_QASYMM8_NEON(arm_compute::cpu::add_qasymm8_neon)
    },
    {
        "add_qasymm8_signed_neon",
        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)); },
        REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::add_qasymm8_signed_neon)
    },
    {
        "add_qsymm16_neon",
        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)); },
        REGISTER_QSYMM16_NEON(arm_compute::cpu::add_qsymm16_neon)
    },
#endif /* defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE) */
};
227
228/** Micro-kernel selector
229 *
230 * @param[in] data Selection data passed to help pick the appropriate micro-kernel
231 *
232 * @return A matching micro-kernel else nullptr
233 */
Michalis Spyrou20fca522021-06-07 14:23:57 +0100234const AddKernel *get_implementation(const CPUInfo &cpuinfo, DataType dt1, DataType dt2, DataType dt3)
Sheri Zhang61243902021-01-12 18:25:16 +0000235{
236 for(const auto &uk : available_kernels)
237 {
Michalis Spyrou20fca522021-06-07 14:23:57 +0100238 if(uk.is_selected({ dt1, dt2, dt3, cpuinfo }))
Sheri Zhang61243902021-01-12 18:25:16 +0000239 {
240 return &uk;
241 }
242 }
243 return nullptr;
244}
245
/** Validate the argument combination for an element-wise addition.
 *
 * Checks data-type support for both inputs, broadcast compatibility of the
 * shapes, the allowed src0/src1/dst data-type triples (when dst is already
 * configured), and that a matching micro-kernel exists in the registry.
 *
 * @param[in] src0   First input tensor info
 * @param[in] src1   Second input tensor info
 * @param[in] dst    Output tensor info (may be uninitialized: total_size() == 0)
 * @param[in] policy Wrap/saturate convert policy (unused here; validated by the selected kernel at run time)
 *
 * @return Status{} on success, an error status otherwise
 */
Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst, ConvertPolicy policy)
{
    ARM_COMPUTE_UNUSED(policy);

    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&src0);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
                                                         DataType::S16, DataType::QSYMM16, DataType::F16,
                                                         DataType::S32, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
                                                         DataType::S16, DataType::QSYMM16, DataType::F16,
                                                         DataType::S32, DataType::F32);

    // Empty broadcast shape means the two shapes are not broadcast compatible
    const TensorShape out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape());

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
    // X-dimension broadcasting is only implemented for same-type kernels
    ARM_COMPUTE_RETURN_ERROR_ON_MSG((src0.tensor_shape().x() != src1.tensor_shape().x()) && ((src0.data_type() != src1.data_type()) || (src0.data_type() != dst.data_type())
                                                                                            || (src1.data_type() != dst.data_type())),
                                    "Broadcasting across width is supported on configurations where all tensors have the same data type");

    // Validate in case of configured dst
    if(dst.total_size() > 0)
    {
        // Whitelist of the src0/src1/dst data-type triples the micro-kernels implement
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(
            !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::U8)
            && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::S16)
            && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::S16 && dst.data_type() == DataType::S16)
            && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::S16)
            && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::S16 && dst.data_type() == DataType::S16)
            && !(src0.data_type() == DataType::S32 && src1.data_type() == DataType::S32 && dst.data_type() == DataType::S32)
            && !(src0.data_type() == DataType::F32 && src1.data_type() == DataType::F32 && dst.data_type() == DataType::F32)
            && !(src0.data_type() == DataType::F16 && src1.data_type() == DataType::F16 && dst.data_type() == DataType::F16)
            && !(src0.data_type() == DataType::QASYMM8 && src1.data_type() == DataType::QASYMM8 && dst.data_type() == DataType::QASYMM8)
            && !(src0.data_type() == DataType::QASYMM8_SIGNED && src1.data_type() == DataType::QASYMM8_SIGNED && dst.data_type() == DataType::QASYMM8_SIGNED)
            && !(src0.data_type() == DataType::QSYMM16 && src1.data_type() == DataType::QSYMM16 && dst.data_type() == DataType::QSYMM16),
            "You called addition with the wrong image formats");

        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst.tensor_shape(), 0),
                                        "Wrong shape for dst");
    }

    // A micro-kernel must exist for this data-type/CPU combination
    const auto *uk = get_implementation(CPUInfo::get(), src0.data_type(), src1.data_type(), dst.data_type());
    ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);

    return Status{};
}
291
292std::pair<Status, Window> validate_and_configure_window(const ITensorInfo &src0, const ITensorInfo &src1, ITensorInfo &dst)
293{
SiCongLic7b1e842021-02-22 14:28:33 +0000294 const TensorShape &out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape());
Sheri Zhang61243902021-01-12 18:25:16 +0000295
296 // Auto initialize dst if not initialized
297 {
298 set_shape_if_empty(dst, out_shape);
299
300 if(src0.data_type() == DataType::S16 || src1.data_type() == DataType::S16)
301 {
302 set_format_if_unknown(dst, Format::S16);
303 }
304 if(src0.data_type() == DataType::S32 || src1.data_type() == DataType::S32)
305 {
306 set_format_if_unknown(dst, Format::S32);
307 }
308 else if(src0.data_type() == DataType::F16 || src1.data_type() == DataType::F16)
309 {
310 set_format_if_unknown(dst, Format::F16);
311 }
312 else if(src0.data_type() == DataType::F32 || src1.data_type() == DataType::F32)
313 {
314 set_format_if_unknown(dst, Format::F32);
315 }
316 else if(src0.data_type() == DataType::QASYMM8 || src1.data_type() == DataType::QASYMM8)
317 {
318 set_data_type_if_unknown(dst, DataType::QASYMM8);
319 }
320 else if(src0.data_type() == DataType::QASYMM8_SIGNED || src1.data_type() == DataType::QASYMM8_SIGNED)
321 {
322 set_data_type_if_unknown(dst, DataType::QASYMM8_SIGNED);
323 }
324 else if(src0.data_type() == DataType::QSYMM16 || src1.data_type() == DataType::QSYMM16)
325 {
326 set_data_type_if_unknown(dst, DataType::QSYMM16);
327 }
328 }
329
SiCongLic7b1e842021-02-22 14:28:33 +0000330 Window win = calculate_max_window(out_shape, Steps());
Sheri Zhang61243902021-01-12 18:25:16 +0000331
332 // CpuAddKernel doesn't need padding so update_window_and_padding() can be skipped
Sheri Zhang61243902021-01-12 18:25:16 +0000333 return std::make_pair(Status{}, win);
334}
335} // namespace
336
337void CpuAddKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst, ConvertPolicy policy)
338{
339 ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
340 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst, policy));
341
342 _policy = policy;
343
344 // Configure kernel window
345 auto win_config = validate_and_configure_window(*src0, *src1, *dst);
346 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
347 ICpuKernel::configure(win_config.second);
348}
349
350Status CpuAddKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst, ConvertPolicy policy)
351{
352 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
353
354 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*src0, *src1, *dst, policy));
355 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(*src0->clone(), *src1->clone(), *dst->clone()).first);
356
357 return Status{};
358}
359
360void CpuAddKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
361{
362 ARM_COMPUTE_UNUSED(info);
363 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
364 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
365
366 ARM_COMPUTE_ERROR_ON(tensors.empty());
367
368 const ITensor *src0 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
369 const ITensor *src1 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
370 ITensor *dst = tensors.get_tensor(TensorType::ACL_DST);
371
Michalis Spyrou20fca522021-06-07 14:23:57 +0100372 const auto *uk = get_implementation(CPUInfo::get(), src0->info()->data_type(), src1->info()->data_type(), dst->info()->data_type());
Sheri Zhang61243902021-01-12 18:25:16 +0000373 ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
374
375 uk->ukernel(src0, src1, dst, _policy, window);
376}
377
378const char *CpuAddKernel::name() const
379{
380 return "CpuAddKernel";
381}
382} // namespace kernels
383} // namespace cpu
384} // namespace arm_compute