blob: cbe5136fb1b22e845f0b9b8ee2f481289a478786 [file] [log] [blame]
giuros0115ecc9a2018-12-06 10:47:34 +00001/*
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +02002 * Copyright (c) 2018-2022 Arm Limited.
giuros0115ecc9a2018-12-06 10:47:34 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
Michalis Spyrouaeebe4a2019-01-09 14:21:03 +000017 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
giuros0115ecc9a2018-12-06 10:47:34 +000018 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
Michalis Spyrouaeebe4a2019-01-09 14:21:03 +000019 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
giuros0115ecc9a2018-12-06 10:47:34 +000020 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrouebcebf12020-10-21 00:04:14 +010024#include "src/core/NEON/kernels/NEFuseBatchNormalizationKernel.h"
giuros0115ecc9a2018-12-06 10:47:34 +000025
giuros0115ecc9a2018-12-06 10:47:34 +000026#include "arm_compute/core/Helpers.h"
27#include "arm_compute/core/ITensor.h"
28#include "arm_compute/core/TensorInfo.h"
Georgios Pinitas8f5802f2019-02-22 11:08:32 +000029#include "arm_compute/core/Utils.h"
30#include "arm_compute/core/Validate.h"
31#include "arm_compute/core/Window.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010032
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +020033#include "src/common/cpuinfo/CpuIsaInfo.h"
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +020034#include "src/core/common/Registrars.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010035#include "src/core/CPP/Validate.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010036#include "src/core/helpers/AutoConfiguration.h"
37#include "src/core/helpers/WindowHelpers.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010038#include "src/core/NEON/wrapper/wrapper.h"
39#include "src/cpu/kernels/fuse_batch_normalization/list.h"
giuros0115ecc9a2018-12-06 10:47:34 +000040
Manuel Bottini11091762019-06-17 12:04:40 +010041#include <map>
42
giuros0115ecc9a2018-12-06 10:47:34 +000043namespace arm_compute
44{
45namespace
46{
/** Criteria handed to the is_selected predicates of available_kernels when picking a micro-kernel.
 *
 * configure() builds one of these positionally (dt, dl, fbn_type, isa), so the member order must not change.
 */
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +020047struct FuseBatchNormalizeSelectorData
48{
/** Data type; configure() fills it from the input weights' tensor info */
49 DataType dt;
/** Data layout; configure() fills it from the input weights' tensor info */
50 DataLayout dl;
/** Kind of layer being fused (the table below distinguishes CONVOLUTION and DEPTHWISECONVOLUTION) */
51 FuseBatchNormalizationType fbn_type;
/** ISA description; configure() fills it from CPUInfo::get().get_isa(). Only the fp16 flag is read in this file */
52 cpuinfo::CpuIsaInfo isa;
53};
54
55using FBNSelectorPtr = std::add_pointer<bool(const FuseBatchNormalizeSelectorData &data)>::type;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010056using FBNUKernelPtr = std::add_pointer<void(const ITensor *,
57 const ITensor *,
58 ITensor *,
59 ITensor *,
60 const ITensor *,
61 const ITensor *,
62 const ITensor *,
63 const ITensor *,
64 float,
65 const Window &)>::type;
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +020066
/** One entry of the micro-kernel table (available_kernels).
 *
 * Entries are brace-initialised positionally (name, is_selected, ukernel), so the member order must not change.
 */
67struct FBNUKernel
68{
/** Human-readable identifier of the micro-kernel variant */
69 const char *name;
/** Predicate returning true when this entry matches the selection data */
70 const FBNSelectorPtr is_selected;
/** Function run by the kernel; configure() asserts it is non-null, so presumably the REGISTER_* macros can yield nullptr - TODO confirm */
71 FBNUKernelPtr ukernel;
72};
73
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010074static const FBNUKernel available_kernels[] = {
75 {"fused_batch_normalization_conv_NHWC_F16",
76 [](const FuseBatchNormalizeSelectorData &data)
77 {
78 return data.dt == DataType::F16 && data.dl == DataLayout::NHWC && data.isa.fp16 &&
79 data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
80 },
81 REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_conv_f16)},
82 {"fused_batch_normalization_conv_NCHW_F16",
83 [](const FuseBatchNormalizeSelectorData &data)
84 {
85 return data.dt == DataType::F16 && data.dl == DataLayout::NCHW && data.isa.fp16 &&
86 data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
87 },
88 REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_conv_f16)},
89 {"fused_batch_normalization_dwc_NHWC_F16",
90 [](const FuseBatchNormalizeSelectorData &data)
91 {
92 return data.dt == DataType::F16 && data.dl == DataLayout::NHWC && data.isa.fp16 &&
93 data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
94 },
95 REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nhwc_f16)},
96 {"fused_batch_normalization_dwc_NCHW_F16",
97 [](const FuseBatchNormalizeSelectorData &data)
98 {
99 return data.dt == DataType::F16 && data.dl == DataLayout::NCHW && data.isa.fp16 &&
100 data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
101 },
102 REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nchw_f16)},
103 {"fused_batch_normalization_conv_NHWC_F32",
104 [](const FuseBatchNormalizeSelectorData &data)
105 {
106 return data.dt == DataType::F32 && data.dl == DataLayout::NHWC &&
107 data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
108 },
109 REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_conv_f32)},
110 {"fused_batch_normalization_conv_NCHW_F32",
111 [](const FuseBatchNormalizeSelectorData &data)
112 {
113 return data.dt == DataType::F32 && data.dl == DataLayout::NCHW &&
114 data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
115 },
116 REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_conv_f32)},
117 {"fused_batch_normalization_dwc_NHWC_F32",
118 [](const FuseBatchNormalizeSelectorData &data)
119 {
120 return data.dt == DataType::F32 && data.dl == DataLayout::NHWC &&
121 data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
122 },
123 REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nhwc_f32)},
124 {"fused_batch_normalization_dwc_NCHW_F32",
125 [](const FuseBatchNormalizeSelectorData &data)
126 {
127 return data.dt == DataType::F32 && data.dl == DataLayout::NCHW &&
128 data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
129 },
130 REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nchw_f32)}};
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +0200131
132/** Micro-kernel selector
133 *
134 * @param[in] data Selection data passed to help pick the appropriate micro-kernel
135 *
136 * @param[in]
137 *
138 * @return A matching micro-kernel else nullptr
139 */
140const FBNUKernel *get_implementation(const FuseBatchNormalizeSelectorData &data)
141{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100142 for (const auto &uk : available_kernels)
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +0200143 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100144 if (uk.is_selected(data))
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +0200145 {
146 return &uk;
147 }
148 }
149 return nullptr;
150}
151
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100152Status validate_arguments(const ITensorInfo *input_weights,
153 const ITensorInfo *bn_mean,
154 const ITensorInfo *bn_var,
155 const ITensorInfo *fused_weights,
156 const ITensorInfo *fused_bias,
157 const ITensorInfo *input_bias,
158 const ITensorInfo *bn_beta,
159 const ITensorInfo *bn_gamma,
160 float epsilon,
161 FuseBatchNormalizationType fbn_type)
giuros0115ecc9a2018-12-06 10:47:34 +0000162{
163 ARM_COMPUTE_UNUSED(epsilon);
Manuel Bottini11091762019-06-17 12:04:40 +0100164 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input_weights, bn_mean, bn_var);
165 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input_weights);
166 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_weights, 1, DataType::F16, DataType::F32);
giuros0115ecc9a2018-12-06 10:47:34 +0000167 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_var);
Manuel Bottini11091762019-06-17 12:04:40 +0100168 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_mean, bn_var);
169 ARM_COMPUTE_RETURN_ERROR_ON(input_bias == nullptr && fused_bias == nullptr);
170 ARM_COMPUTE_RETURN_ERROR_ON(bn_mean->num_dimensions() > 1);
giuros0115ecc9a2018-12-06 10:47:34 +0000171
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100172 if (fbn_type == FuseBatchNormalizationType::CONVOLUTION)
giuros0115ecc9a2018-12-06 10:47:34 +0000173 {
Manuel Bottini11091762019-06-17 12:04:40 +0100174 ARM_COMPUTE_RETURN_ERROR_ON(input_weights->dimension(3) != bn_mean->dimension(0));
175 }
176 else
177 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100178 const size_t channel_idx =
179 get_data_layout_dimension_index(input_weights->data_layout(), DataLayoutDimension::CHANNEL);
Manuel Bottini11091762019-06-17 12:04:40 +0100180 ARM_COMPUTE_RETURN_ERROR_ON(input_weights->dimension(channel_idx) != bn_mean->dimension(0));
181 }
182 // Validate bias
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100183 if (input_bias != nullptr)
Manuel Bottini11091762019-06-17 12:04:40 +0100184 {
185 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, input_bias);
186 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, input_bias);
giuros0115ecc9a2018-12-06 10:47:34 +0000187 }
188 // Validate beta
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100189 if (bn_beta != nullptr)
giuros0115ecc9a2018-12-06 10:47:34 +0000190 {
191 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_beta);
Manuel Bottini11091762019-06-17 12:04:40 +0100192 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_beta);
giuros0115ecc9a2018-12-06 10:47:34 +0000193 }
194 // Validate gamma
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100195 if (bn_gamma != nullptr)
giuros0115ecc9a2018-12-06 10:47:34 +0000196 {
197 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_gamma);
Manuel Bottini11091762019-06-17 12:04:40 +0100198 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_gamma);
giuros0115ecc9a2018-12-06 10:47:34 +0000199 }
200
201 // Validate output weights
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100202 if (fused_weights != nullptr && fused_weights->total_size() != 0)
giuros0115ecc9a2018-12-06 10:47:34 +0000203 {
Manuel Bottini11091762019-06-17 12:04:40 +0100204 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_weights, fused_weights);
205 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input_weights, fused_weights);
206 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, fused_weights);
giuros0115ecc9a2018-12-06 10:47:34 +0000207 }
208 // Validate output bias
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100209 if (fused_bias != nullptr && fused_bias->total_size() != 0)
giuros0115ecc9a2018-12-06 10:47:34 +0000210 {
211 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, fused_bias);
Manuel Bottini11091762019-06-17 12:04:40 +0100212 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, fused_bias);
giuros0115ecc9a2018-12-06 10:47:34 +0000213 }
214
215 return Status{};
216}
217
giuros0115ecc9a2018-12-06 10:47:34 +0000218} // namespace
219
220NEFuseBatchNormalizationKernel::NEFuseBatchNormalizationKernel()
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100221 : _input_weights(nullptr),
222 _input_bias(nullptr),
223 _bn_mean(nullptr),
224 _bn_var(nullptr),
225 _bn_gamma(nullptr),
226 _bn_beta(nullptr),
227 _fused_weights(nullptr),
228 _fused_bias(nullptr),
229 _epsilon(),
230 _run_in_place_weights(false),
231 _run_in_place_bias(false),
232 _func(nullptr)
giuros0115ecc9a2018-12-06 10:47:34 +0000233{
234}
235
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100236void NEFuseBatchNormalizationKernel::configure(const ITensor *input_weights,
237 const ITensor *bn_mean,
238 const ITensor *bn_var,
239 ITensor *fused_weights,
240 ITensor *fused_bias,
241 const ITensor *input_bias,
242 const ITensor *bn_beta,
243 const ITensor *bn_gamma,
244 float epsilon,
245 FuseBatchNormalizationType fbn_type)
giuros0115ecc9a2018-12-06 10:47:34 +0000246{
Manuel Bottini11091762019-06-17 12:04:40 +0100247 ARM_COMPUTE_ERROR_ON_NULLPTR(input_weights, bn_mean, bn_var);
giuros0115ecc9a2018-12-06 10:47:34 +0000248
Manuel Bottini11091762019-06-17 12:04:40 +0100249 _input_weights = input_weights;
250 _input_bias = input_bias;
giuros0115ecc9a2018-12-06 10:47:34 +0000251 _bn_mean = bn_mean;
252 _bn_var = bn_var;
253 _bn_beta = bn_beta;
254 _bn_gamma = bn_gamma;
255 _fused_weights = fused_weights;
256 _fused_bias = fused_bias;
257 _epsilon = epsilon;
258
Manuel Bottini11091762019-06-17 12:04:40 +0100259 _run_in_place_weights = (fused_weights == nullptr) || (fused_weights == input_weights);
260 _run_in_place_bias = (fused_bias == nullptr) || (input_bias != nullptr && fused_bias == input_bias);
giuros0115ecc9a2018-12-06 10:47:34 +0000261
262 // Auto initialize outputs
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100263 if (_fused_weights != nullptr)
giuros0115ecc9a2018-12-06 10:47:34 +0000264 {
265 // Output tensor auto initialization if not yet initialized
Manuel Bottini11091762019-06-17 12:04:40 +0100266 auto_init_if_empty(*_fused_weights->info(), *_input_weights->info()->clone());
giuros0115ecc9a2018-12-06 10:47:34 +0000267 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100268 if (_fused_bias != nullptr)
giuros0115ecc9a2018-12-06 10:47:34 +0000269 {
270 // Output tensor auto initialization if not yet initialized
271 auto_init_if_empty(*_fused_bias->info(), *_bn_mean->info()->clone());
giuros0115ecc9a2018-12-06 10:47:34 +0000272 }
273
274 // Validate arguments
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100275 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(
276 input_weights->info(), bn_mean->info(), bn_var->info(),
277 (fused_weights != nullptr) ? fused_weights->info() : nullptr,
278 (fused_bias != nullptr) ? fused_bias->info() : nullptr, (input_bias != nullptr) ? input_bias->info() : nullptr,
279 (bn_beta != nullptr) ? bn_beta->info() : nullptr, (bn_gamma != nullptr) ? bn_gamma->info() : nullptr, epsilon,
280 fbn_type));
giuros0115ecc9a2018-12-06 10:47:34 +0000281
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100282 const auto *uk = get_implementation(FuseBatchNormalizeSelectorData{
283 input_weights->info()->data_type(), input_weights->info()->data_layout(), fbn_type, CPUInfo::get().get_isa()});
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +0200284 ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
285 ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
286 _func = uk->ukernel;
287
giuros0115ecc9a2018-12-06 10:47:34 +0000288 // Configure kernel window
Manuel Bottini11091762019-06-17 12:04:40 +0100289 Window win = calculate_max_window(*input_weights->info());
giuros0115ecc9a2018-12-06 10:47:34 +0000290 INEKernel::configure(win);
giuros0115ecc9a2018-12-06 10:47:34 +0000291}
292
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100293Status NEFuseBatchNormalizationKernel::validate(const ITensorInfo *input_weights,
294 const ITensorInfo *bn_mean,
295 const ITensorInfo *bn_var,
296 const ITensorInfo *fused_weights,
297 const ITensorInfo *fused_bias,
298 const ITensorInfo *input_bias,
299 const ITensorInfo *bn_beta,
300 const ITensorInfo *bn_gamma,
301 float epsilon,
302 FuseBatchNormalizationType fbn_type)
giuros0115ecc9a2018-12-06 10:47:34 +0000303{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100304 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input_weights, bn_mean, bn_var, fused_weights, fused_bias,
305 input_bias, bn_beta, bn_gamma, epsilon, fbn_type));
giuros0115ecc9a2018-12-06 10:47:34 +0000306 return Status{};
307}
308
309void NEFuseBatchNormalizationKernel::run(const Window &window, const ThreadInfo &info)
310{
311 ARM_COMPUTE_UNUSED(info);
312 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
313 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +0200314
315 ARM_COMPUTE_ERROR_ON(_func == nullptr);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100316 (*_func)(_input_weights, _input_bias, _fused_weights, _fused_bias, _bn_mean, _bn_var, _bn_beta, _bn_gamma, _epsilon,
317 window);
giuros0115ecc9a2018-12-06 10:47:34 +0000318}
319} // namespace arm_compute