blob: 6a632e87022cfa2a33aca21551f667943bdc7bcb [file] [log] [blame]
/*
 * Copyright (c) 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "src/cpu/kernels/CpuAddMulAddKernel.h"
25
26#include "arm_compute/core/ITensor.h"
27#include "arm_compute/core/TensorInfo.h"
28#include "arm_compute/core/Validate.h"
29
Gunes Bayirae72a462023-01-29 13:24:24 +000030#include "src/core/common/Registrars.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010031#include "src/core/CPP/Validate.h"
Gunes Bayirae72a462023-01-29 13:24:24 +000032#include "src/core/helpers/AutoConfiguration.h"
33#include "src/core/helpers/WindowHelpers.h"
34#include "src/cpu/kernels/addmuladd/list.h"
35
36namespace arm_compute
37{
38namespace cpu
39{
40namespace kernels
41{
42namespace
43{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010044static const std::vector<CpuAddMulAddKernel::AddMulAddKernel> available_kernels = {
Gunes Bayirae72a462023-01-29 13:24:24 +000045#ifdef __aarch64__
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010046 {"neon_fp32_add_mul_add", [](const DataTypeISASelectorData &data) { return (data.dt == DataType::F32); },
47 REGISTER_FP32_NEON(arm_compute::cpu::add_mul_add_fp32_neon)},
48 {"neon_fp16_add_mul_add", [](const DataTypeISASelectorData &data) { return (data.dt == DataType::F16); },
49 REGISTER_FP16_NEON(arm_compute::cpu::add_mul_add_fp16_neon)},
50 {"neon_qasymm8_add_mul_add", [](const DataTypeISASelectorData &data) { return (data.dt == DataType::QASYMM8); },
51 REGISTER_QASYMM8_NEON(arm_compute::cpu::add_mul_add_u8_neon)},
52 {"neon_qasymm8_signed_add_mul_add",
53 [](const DataTypeISASelectorData &data) { return (data.dt == DataType::QASYMM8_SIGNED); },
54 REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::add_mul_add_s8_neon)}
Gunes Bayirae72a462023-01-29 13:24:24 +000055#endif // __aarch64__
56};
57
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010058Status validate_arguments(const ITensorInfo *input1,
59 const ITensorInfo *input2,
60 const ITensorInfo *bn_mul,
61 const ITensorInfo *bn_add,
62 const ITensorInfo *add_output,
63 const ITensorInfo *final_output,
64 ConvertPolicy policy,
65 const ActivationLayerInfo &act_info)
Gunes Bayirae72a462023-01-29 13:24:24 +000066{
67 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, bn_mul, bn_add, final_output);
68
69 ARM_COMPUTE_RETURN_ERROR_ON_MSG(policy != ConvertPolicy::SATURATE, "Only Saturate Policy is supported");
70
71 using ActFunction = ActivationLayerInfo::ActivationFunction;
72 const ActFunction act_func = act_info.activation();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010073 ARM_COMPUTE_RETURN_ERROR_ON_MSG((act_func != ActFunction::BOUNDED_RELU && act_func != ActFunction::RELU &&
74 act_func != ActFunction::LU_BOUNDED_RELU && act_func != ActFunction::IDENTITY),
75 "Only RELU Family activations, or no activation, is supported");
Gunes Bayirae72a462023-01-29 13:24:24 +000076
77 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
78 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
79 DataType::F16, DataType::F32);
80 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
81
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010082 if (is_data_type_quantized(input1->data_type()))
Gunes Bayirae72a462023-01-29 13:24:24 +000083 {
84 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bn_mul, 1, DataType::F32);
85 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bn_add, 1, DataType::F32);
86 }
87 else
88 {
89 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, bn_mul);
90 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, bn_add);
91 }
92
93 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); // No broadcasting
94 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mul, bn_add);
95 ARM_COMPUTE_RETURN_ERROR_ON_MSG(bn_mul->num_dimensions() != 1, "BatchNorm coefficients should be 1D array");
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010096 ARM_COMPUTE_RETURN_ERROR_ON_MSG(bn_mul->tensor_shape()[0] != input1->tensor_shape()[0],
97 "First dimensions of inputs and batchNorm coefs should match");
Gunes Bayirae72a462023-01-29 13:24:24 +000098
99 // Validate in case we have add layer's output (intermediate) initialized
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100100 if (add_output != nullptr && add_output->total_size() > 0)
Gunes Bayirae72a462023-01-29 13:24:24 +0000101 {
102 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, add_output);
103 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, add_output);
104 }
105
106 // Validate in case final output has been initialized
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100107 if (final_output->total_size() > 0)
Gunes Bayirae72a462023-01-29 13:24:24 +0000108 {
109 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, final_output);
110 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, final_output);
111 }
112
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100113 const auto uk = CpuAddMulAddKernel::get_implementation<DataTypeISASelectorData>(
114 DataTypeISASelectorData{input1->data_type(), CPUInfo::get().get_isa()});
Gunes Bayirae72a462023-01-29 13:24:24 +0000115 ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
116
117 return Status{};
118}
119} // namespace
120
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100121void CpuAddMulAddKernel::configure(const ITensorInfo *input1,
122 const ITensorInfo *input2,
123 const ITensorInfo *bn_mul,
124 const ITensorInfo *bn_add,
125 ITensorInfo *add_output,
126 ITensorInfo *final_output,
127 ConvertPolicy policy,
128 const ActivationLayerInfo &act_info)
Gunes Bayirae72a462023-01-29 13:24:24 +0000129{
130 ARM_COMPUTE_UNUSED(bn_mul, bn_add, input2);
131 ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, bn_add, bn_mul, final_output);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100132 ARM_COMPUTE_ERROR_THROW_ON(
133 validate_arguments(input1, input2, bn_mul, bn_add, add_output, final_output, policy, act_info));
Gunes Bayirae72a462023-01-29 13:24:24 +0000134
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100135 const auto uk = CpuAddMulAddKernel::get_implementation<DataTypeISASelectorData>(
136 DataTypeISASelectorData{input1->data_type(), CPUInfo::get().get_isa()});
Gunes Bayirae72a462023-01-29 13:24:24 +0000137 ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
138 ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
139
140 _policy = policy;
141 _act_info = act_info;
142 _run_method = uk->ukernel;
143 _name = std::string("CpuAddMulAddKernel/").append(uk->name);
144
145 // Auto initialize outputs if not initialized
146 set_shape_if_empty(*final_output, input1->tensor_shape());
147 set_data_type_if_unknown(*final_output, input1->data_type());
148
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100149 if (add_output != nullptr)
Gunes Bayirae72a462023-01-29 13:24:24 +0000150 {
151 set_shape_if_empty(*add_output, input1->tensor_shape());
152 set_data_type_if_unknown(*add_output, input1->data_type());
153 }
154
155 // Configure kernel window
156 Window win;
157 win = calculate_max_window(*final_output, Steps());
158 ICpuKernel::configure(win);
159}
160
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100161Status CpuAddMulAddKernel::validate(const ITensorInfo *input1,
162 const ITensorInfo *input2,
163 const ITensorInfo *bn_mul,
164 const ITensorInfo *bn_add,
165 const ITensorInfo *add_output,
166 const ITensorInfo *final_output,
167 ConvertPolicy policy,
168 const ActivationLayerInfo &act_info)
Gunes Bayirae72a462023-01-29 13:24:24 +0000169{
170 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, bn_mul, bn_add, final_output);
171
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100172 ARM_COMPUTE_RETURN_ON_ERROR(
173 validate_arguments(input1, input2, bn_mul, bn_add, add_output, final_output, policy, act_info));
Gunes Bayirae72a462023-01-29 13:24:24 +0000174
175 return Status{};
176}
177
178void CpuAddMulAddKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
179{
180 ARM_COMPUTE_UNUSED(info);
181
182 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
183 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
184
185 ARM_COMPUTE_ERROR_ON(tensors.empty());
186 ARM_COMPUTE_ERROR_ON(_run_method == nullptr);
187
188 const ITensor *input1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
189 const ITensor *input2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
190 const ITensor *bn_mul = tensors.get_const_tensor(TensorType::ACL_SRC_2);
191 const ITensor *bn_add = tensors.get_const_tensor(TensorType::ACL_SRC_3);
192 ITensor *add_output = tensors.get_tensor(TensorType::ACL_DST_0);
193 ITensor *final_output = tensors.get_tensor(TensorType::ACL_DST_1);
194
195 _run_method(input1, input2, bn_mul, bn_add, add_output, final_output, _policy, _act_info, window);
196}
197
198const char *CpuAddMulAddKernel::name() const
199{
200 return _name.c_str();
201}
202
203const std::vector<CpuAddMulAddKernel::AddMulAddKernel> &CpuAddMulAddKernel::get_available_kernels()
204{
205 return available_kernels;
206}
207} // namespace kernels
208} // namespace cpu
209} // namespace arm_compute