blob: 51a69046a90f289557b7db6dcd350891c3c8d3ce [file] [log] [blame]
giuros0115ecc9a2018-12-06 10:47:34 +00001/*
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +02002 * Copyright (c) 2018-2022 Arm Limited.
giuros0115ecc9a2018-12-06 10:47:34 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
Michalis Spyrouaeebe4a2019-01-09 14:21:03 +000017 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
giuros0115ecc9a2018-12-06 10:47:34 +000018 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
Michalis Spyrouaeebe4a2019-01-09 14:21:03 +000019 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
giuros0115ecc9a2018-12-06 10:47:34 +000020 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrouebcebf12020-10-21 00:04:14 +010024#include "src/core/NEON/kernels/NEFuseBatchNormalizationKernel.h"
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +020025#include "src/cpu/kernels/fuse_batch_normalization/list.h"
giuros0115ecc9a2018-12-06 10:47:34 +000026
giuros0115ecc9a2018-12-06 10:47:34 +000027#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/ITensor.h"
29#include "arm_compute/core/TensorInfo.h"
Georgios Pinitas8f5802f2019-02-22 11:08:32 +000030#include "arm_compute/core/Utils.h"
31#include "arm_compute/core/Validate.h"
32#include "arm_compute/core/Window.h"
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +020033#include "src/common/cpuinfo/CpuIsaInfo.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010034#include "src/core/CPP/Validate.h"
Georgios Pinitasddb93bb2020-10-02 16:38:59 +010035#include "src/core/NEON/wrapper/wrapper.h"
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +020036#include "src/core/common/Registrars.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010037#include "src/core/helpers/AutoConfiguration.h"
38#include "src/core/helpers/WindowHelpers.h"
giuros0115ecc9a2018-12-06 10:47:34 +000039
Manuel Bottini11091762019-06-17 12:04:40 +010040#include <map>
41
giuros0115ecc9a2018-12-06 10:47:34 +000042namespace arm_compute
43{
44namespace
45{
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +020046struct FuseBatchNormalizeSelectorData
47{
48 DataType dt;
49 DataLayout dl;
50 FuseBatchNormalizationType fbn_type;
51 cpuinfo::CpuIsaInfo isa;
52};
53
54using FBNSelectorPtr = std::add_pointer<bool(const FuseBatchNormalizeSelectorData &data)>::type;
55using FBNUKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, ITensor *,
56 const ITensor *, const ITensor *, const ITensor *, const ITensor *, float, const Window &)>::type;
57
58struct FBNUKernel
59{
60 const char *name;
61 const FBNSelectorPtr is_selected;
62 FBNUKernelPtr ukernel;
63};
64
65static const FBNUKernel available_kernels[] =
66{
67 {
68 "fused_batch_normalization_conv_NHWC_F16",
69 [](const FuseBatchNormalizeSelectorData & data)
70 {
71 return data.dt == DataType::F16 && data.dl == DataLayout::NHWC && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
72 },
73 REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_conv_f16)
74 },
75 {
76 "fused_batch_normalization_conv_NCHW_F16",
77 [](const FuseBatchNormalizeSelectorData & data)
78 {
79 return data.dt == DataType::F16 && data.dl == DataLayout::NCHW && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
80 },
81 REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_conv_f16)
82 },
83 {
84 "fused_batch_normalization_dwc_NHWC_F16",
85 [](const FuseBatchNormalizeSelectorData & data)
86 {
87 return data.dt == DataType::F16 && data.dl == DataLayout::NHWC && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
88 },
89 REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nhwc_f16)
90 },
91 {
92 "fused_batch_normalization_dwc_NCHW_F16",
93 [](const FuseBatchNormalizeSelectorData & data)
94 {
95 return data.dt == DataType::F16 && data.dl == DataLayout::NCHW && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
96 },
97 REGISTER_FP16_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nchw_f16)
98 },
99 {
100 "fused_batch_normalization_conv_NHWC_F32",
101 [](const FuseBatchNormalizeSelectorData & data)
102 {
103 return data.dt == DataType::F32 && data.dl == DataLayout::NHWC && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
104 },
105 REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_conv_f32)
106 },
107 {
108 "fused_batch_normalization_conv_NCHW_F32",
109 [](const FuseBatchNormalizeSelectorData & data)
110 {
111 return data.dt == DataType::F32 && data.dl == DataLayout::NCHW && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
112 },
113 REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_conv_f32)
114 },
115 {
116 "fused_batch_normalization_dwc_NHWC_F32",
117 [](const FuseBatchNormalizeSelectorData & data)
118 {
119 return data.dt == DataType::F32 && data.dl == DataLayout::NHWC && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
120 },
121 REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nhwc_f32)
122 },
123 {
124 "fused_batch_normalization_dwc_NCHW_F32",
125 [](const FuseBatchNormalizeSelectorData & data)
126 {
127 return data.dt == DataType::F32 && data.dl == DataLayout::NCHW && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
128 },
129 REGISTER_FP32_NEON(arm_compute::cpu::fused_batch_normalization_dwc_nchw_f32)
130 }
131};
132
133/** Micro-kernel selector
134 *
135 * @param[in] data Selection data passed to help pick the appropriate micro-kernel
136 *
137 * @param[in]
138 *
139 * @return A matching micro-kernel else nullptr
140 */
141const FBNUKernel *get_implementation(const FuseBatchNormalizeSelectorData &data)
142{
143 for(const auto &uk : available_kernels)
144 {
145 if(uk.is_selected(data))
146 {
147 return &uk;
148 }
149 }
150 return nullptr;
151}
152
Manuel Bottini11091762019-06-17 12:04:40 +0100153Status validate_arguments(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var,
giuros0115ecc9a2018-12-06 10:47:34 +0000154 const ITensorInfo *fused_weights, const ITensorInfo *fused_bias,
Manuel Bottini11091762019-06-17 12:04:40 +0100155 const ITensorInfo *input_bias, const ITensorInfo *bn_beta, const ITensorInfo *bn_gamma,
156 float epsilon, FuseBatchNormalizationType fbn_type)
giuros0115ecc9a2018-12-06 10:47:34 +0000157{
158 ARM_COMPUTE_UNUSED(epsilon);
Manuel Bottini11091762019-06-17 12:04:40 +0100159 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input_weights, bn_mean, bn_var);
160 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input_weights);
161 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_weights, 1, DataType::F16, DataType::F32);
giuros0115ecc9a2018-12-06 10:47:34 +0000162 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_var);
Manuel Bottini11091762019-06-17 12:04:40 +0100163 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_mean, bn_var);
164 ARM_COMPUTE_RETURN_ERROR_ON(input_bias == nullptr && fused_bias == nullptr);
165 ARM_COMPUTE_RETURN_ERROR_ON(bn_mean->num_dimensions() > 1);
giuros0115ecc9a2018-12-06 10:47:34 +0000166
Manuel Bottini11091762019-06-17 12:04:40 +0100167 if(fbn_type == FuseBatchNormalizationType::CONVOLUTION)
giuros0115ecc9a2018-12-06 10:47:34 +0000168 {
Manuel Bottini11091762019-06-17 12:04:40 +0100169 ARM_COMPUTE_RETURN_ERROR_ON(input_weights->dimension(3) != bn_mean->dimension(0));
170 }
171 else
172 {
173 const size_t channel_idx = get_data_layout_dimension_index(input_weights->data_layout(), DataLayoutDimension::CHANNEL);
174 ARM_COMPUTE_RETURN_ERROR_ON(input_weights->dimension(channel_idx) != bn_mean->dimension(0));
175 }
176 // Validate bias
177 if(input_bias != nullptr)
178 {
179 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, input_bias);
180 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, input_bias);
giuros0115ecc9a2018-12-06 10:47:34 +0000181 }
182 // Validate beta
183 if(bn_beta != nullptr)
184 {
185 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_beta);
Manuel Bottini11091762019-06-17 12:04:40 +0100186 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_beta);
giuros0115ecc9a2018-12-06 10:47:34 +0000187 }
188 // Validate gamma
189 if(bn_gamma != nullptr)
190 {
191 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_gamma);
Manuel Bottini11091762019-06-17 12:04:40 +0100192 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_gamma);
giuros0115ecc9a2018-12-06 10:47:34 +0000193 }
194
195 // Validate output weights
196 if(fused_weights != nullptr && fused_weights->total_size() != 0)
197 {
Manuel Bottini11091762019-06-17 12:04:40 +0100198 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_weights, fused_weights);
199 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input_weights, fused_weights);
200 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, fused_weights);
giuros0115ecc9a2018-12-06 10:47:34 +0000201 }
202 // Validate output bias
203 if(fused_bias != nullptr && fused_bias->total_size() != 0)
204 {
205 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, fused_bias);
Manuel Bottini11091762019-06-17 12:04:40 +0100206 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, fused_bias);
giuros0115ecc9a2018-12-06 10:47:34 +0000207 }
208
209 return Status{};
210}
211
giuros0115ecc9a2018-12-06 10:47:34 +0000212} // namespace
213
214NEFuseBatchNormalizationKernel::NEFuseBatchNormalizationKernel()
Manuel Bottini11091762019-06-17 12:04:40 +0100215 : _input_weights(nullptr), _input_bias(nullptr), _bn_mean(nullptr), _bn_var(nullptr), _bn_gamma(nullptr), _bn_beta(nullptr), _fused_weights(nullptr), _fused_bias(nullptr), _epsilon(),
giuros0115ecc9a2018-12-06 10:47:34 +0000216 _run_in_place_weights(false), _run_in_place_bias(false), _func(nullptr)
217{
218}
219
Manuel Bottini11091762019-06-17 12:04:40 +0100220void NEFuseBatchNormalizationKernel::configure(const ITensor *input_weights, const ITensor *bn_mean, const ITensor *bn_var,
giuros0115ecc9a2018-12-06 10:47:34 +0000221 ITensor *fused_weights, ITensor *fused_bias,
Manuel Bottini11091762019-06-17 12:04:40 +0100222 const ITensor *input_bias, const ITensor *bn_beta, const ITensor *bn_gamma,
223 float epsilon, FuseBatchNormalizationType fbn_type)
giuros0115ecc9a2018-12-06 10:47:34 +0000224{
Manuel Bottini11091762019-06-17 12:04:40 +0100225 ARM_COMPUTE_ERROR_ON_NULLPTR(input_weights, bn_mean, bn_var);
giuros0115ecc9a2018-12-06 10:47:34 +0000226
Manuel Bottini11091762019-06-17 12:04:40 +0100227 _input_weights = input_weights;
228 _input_bias = input_bias;
giuros0115ecc9a2018-12-06 10:47:34 +0000229 _bn_mean = bn_mean;
230 _bn_var = bn_var;
231 _bn_beta = bn_beta;
232 _bn_gamma = bn_gamma;
233 _fused_weights = fused_weights;
234 _fused_bias = fused_bias;
235 _epsilon = epsilon;
236
Manuel Bottini11091762019-06-17 12:04:40 +0100237 _run_in_place_weights = (fused_weights == nullptr) || (fused_weights == input_weights);
238 _run_in_place_bias = (fused_bias == nullptr) || (input_bias != nullptr && fused_bias == input_bias);
giuros0115ecc9a2018-12-06 10:47:34 +0000239
240 // Auto initialize outputs
241 if(_fused_weights != nullptr)
242 {
243 // Output tensor auto initialization if not yet initialized
Manuel Bottini11091762019-06-17 12:04:40 +0100244 auto_init_if_empty(*_fused_weights->info(), *_input_weights->info()->clone());
giuros0115ecc9a2018-12-06 10:47:34 +0000245 }
246 if(_fused_bias != nullptr)
247 {
248 // Output tensor auto initialization if not yet initialized
249 auto_init_if_empty(*_fused_bias->info(), *_bn_mean->info()->clone());
giuros0115ecc9a2018-12-06 10:47:34 +0000250 }
251
252 // Validate arguments
Manuel Bottini11091762019-06-17 12:04:40 +0100253 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input_weights->info(), bn_mean->info(), bn_var->info(),
giuros0115ecc9a2018-12-06 10:47:34 +0000254 (fused_weights != nullptr) ? fused_weights->info() : nullptr,
255 (fused_bias != nullptr) ? fused_bias->info() : nullptr,
Manuel Bottini11091762019-06-17 12:04:40 +0100256 (input_bias != nullptr) ? input_bias->info() : nullptr,
giuros0115ecc9a2018-12-06 10:47:34 +0000257 (bn_beta != nullptr) ? bn_beta->info() : nullptr,
258 (bn_gamma != nullptr) ? bn_gamma->info() : nullptr,
Manuel Bottini11091762019-06-17 12:04:40 +0100259 epsilon, fbn_type));
giuros0115ecc9a2018-12-06 10:47:34 +0000260
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +0200261 const auto *uk = get_implementation(FuseBatchNormalizeSelectorData{ input_weights->info()->data_type(), input_weights->info()->data_layout(), fbn_type, CPUInfo::get().get_isa() });
262 ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
263 ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
264 _func = uk->ukernel;
265
giuros0115ecc9a2018-12-06 10:47:34 +0000266 // Configure kernel window
Manuel Bottini11091762019-06-17 12:04:40 +0100267 Window win = calculate_max_window(*input_weights->info());
giuros0115ecc9a2018-12-06 10:47:34 +0000268 INEKernel::configure(win);
giuros0115ecc9a2018-12-06 10:47:34 +0000269}
270
Manuel Bottini11091762019-06-17 12:04:40 +0100271Status NEFuseBatchNormalizationKernel::validate(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var,
giuros0115ecc9a2018-12-06 10:47:34 +0000272 const ITensorInfo *fused_weights, const ITensorInfo *fused_bias,
Manuel Bottini11091762019-06-17 12:04:40 +0100273 const ITensorInfo *input_bias, const ITensorInfo *bn_beta, const ITensorInfo *bn_gamma,
274 float epsilon, FuseBatchNormalizationType fbn_type)
giuros0115ecc9a2018-12-06 10:47:34 +0000275{
Manuel Bottini11091762019-06-17 12:04:40 +0100276 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input_weights, bn_mean, bn_var, fused_weights, fused_bias, input_bias, bn_beta, bn_gamma, epsilon, fbn_type));
giuros0115ecc9a2018-12-06 10:47:34 +0000277 return Status{};
278}
279
280void NEFuseBatchNormalizationKernel::run(const Window &window, const ThreadInfo &info)
281{
282 ARM_COMPUTE_UNUSED(info);
283 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
284 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
Yair Schwarzbaum41a729e2021-11-15 20:42:47 +0200285
286 ARM_COMPUTE_ERROR_ON(_func == nullptr);
Manuel Bottini11091762019-06-17 12:04:40 +0100287 (*_func)(_input_weights, _input_bias, _fused_weights, _fused_bias, _bn_mean, _bn_var, _bn_beta, _bn_gamma, _epsilon, window);
giuros0115ecc9a2018-12-06 10:47:34 +0000288}
289} // namespace arm_compute