blob: 717fd114857f2ebbf2fbeb3d998128761d1ce94b [file] [log] [blame]
/*
 * Copyright (c) 2017-2021, 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
Michalis Spyrouebcebf12020-10-21 00:04:14 +010024#include "src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010025
26#include "arm_compute/core/Helpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/TensorInfo.h"
28#include "arm_compute/core/Utils.h"
29#include "arm_compute/core/Validate.h"
30#include "arm_compute/core/Window.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010031
32#include "src/core/common/Registrars.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010033#include "src/core/CPP/Validate.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010034#include "src/core/helpers/AutoConfiguration.h"
35#include "src/core/helpers/WindowHelpers.h"
Sheri Zhang8d5d78b2020-12-15 20:25:31 +000036#include "src/core/NEON/kernels/batchnormalization/impl/list.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010037#include "src/core/NEON/kernels/detail/NEActivationFunctionDetail.h"
38#include "src/core/NEON/NEFixedPoint.h"
39#include "src/core/NEON/NEMath.h"
40#include "src/core/NEON/wrapper/wrapper.h"
Sheri Zhang8d5d78b2020-12-15 20:25:31 +000041
Georgios Pinitas57c033b2018-02-15 12:29:44 +000042#include <map>
43
Georgios Pinitas980a9162020-06-03 20:16:46 +010044namespace arm_compute
45{
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000046namespace
Anthony Barbier6ff3b192017-09-04 18:44:23 +010047{
Sheri Zhang8d5d78b2020-12-15 20:25:31 +000048struct BatchNormalizationSelectorData
49{
Michalis Spyrou20fca522021-06-07 14:23:57 +010050 DataType dt;
51 const CPUInfo &ci;
Sheri Zhang8d5d78b2020-12-15 20:25:31 +000052};
53using BatchNormalizationSelectorPtr = std::add_pointer<bool(const BatchNormalizationSelectorData &data)>::type;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010054using BatchNormalizationKernelPtr = std::add_pointer<void(ITensor *,
55 ITensor *,
56 const ITensor *,
57 const ITensor *,
58 const ITensor *,
59 const ITensor *,
60 float,
61 ActivationLayerInfo &,
62 const Window &)>::type;
Sheri Zhang8d5d78b2020-12-15 20:25:31 +000063
64struct BatchNormalizationKernel
65{
66 const char *name;
67 const BatchNormalizationSelectorPtr is_selected;
68 BatchNormalizationKernelPtr ukernel;
69};
70
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010071static const BatchNormalizationKernel available_kernels[] = {
Michalis Spyrou20fca522021-06-07 14:23:57 +010072#if defined(ARM_COMPUTE_ENABLE_SVE)
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010073 {"sve_fp16_batch_normalization",
74 [](const BatchNormalizationSelectorData &data) { return data.dt == DataType::F16 && data.ci.has_sve(); },
75 REGISTER_FP16_SVE(arm_compute::cpu::fp16_sve_batch_normalization)},
76 {"sve_fp32_batch_normalization",
77 [](const BatchNormalizationSelectorData &data) { return data.dt == DataType::F32 && data.ci.has_sve(); },
78 REGISTER_FP32_SVE(arm_compute::cpu::fp32_sve_batch_normalization)},
Michalis Spyrou20fca522021-06-07 14:23:57 +010079#endif /* !defined(ARM_COMPUTE_ENABLE_SVE) */
80#if defined(ARM_COMPUTE_ENABLE_NEON)
Sheri Zhang97b3f112021-01-04 17:14:23 +000081#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010082 {"neon_fp16_batch_normalization",
83 [](const BatchNormalizationSelectorData &data) { return data.dt == DataType::F16; },
84 REGISTER_FP16_NEON(arm_compute::cpu::fp16_neon_batch_normalization)},
Sheri Zhang97b3f112021-01-04 17:14:23 +000085#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010086 {"neon_fp32_batch_normalization",
87 [](const BatchNormalizationSelectorData &data) { return data.dt == DataType::F32; },
88 REGISTER_FP32_NEON(arm_compute::cpu::fp32_neon_batch_normalization)},
Michalis Spyrou20fca522021-06-07 14:23:57 +010089#endif /* !defined(ARM_COMPUTE_ENABLE_NEON) */
Sheri Zhang8d5d78b2020-12-15 20:25:31 +000090};
91
92const BatchNormalizationKernel *get_implementation(const BatchNormalizationSelectorData &data)
93{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010094 for (const auto &uk : available_kernels)
Sheri Zhang8d5d78b2020-12-15 20:25:31 +000095 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010096 if (uk.is_selected(data))
Sheri Zhang8d5d78b2020-12-15 20:25:31 +000097 {
98 return &uk;
99 }
100 }
101 return nullptr;
102}
103
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100104Status validate_arguments(const ITensorInfo *input,
105 const ITensorInfo *output,
106 const ITensorInfo *mean,
107 const ITensorInfo *var,
108 const ITensorInfo *beta,
109 const ITensorInfo *gamma,
110 float epsilon,
111 ActivationLayerInfo act_info)
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000112{
113 ARM_COMPUTE_UNUSED(epsilon);
Sheri Zhang8d5d78b2020-12-15 20:25:31 +0000114
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100115 const auto *uk = get_implementation(BatchNormalizationSelectorData{input->data_type(), CPUInfo::get()});
Sheri Zhang8d5d78b2020-12-15 20:25:31 +0000116 ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000117
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100118 if (act_info.enabled())
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000119 {
120 ActivationLayerInfo::ActivationFunction act = act_info.activation();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100121 ARM_COMPUTE_RETURN_ERROR_ON(act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::RELU &&
122 act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU &&
123 act !=
124 ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000125 ARM_COMPUTE_RETURN_ERROR_ON(act_info.b() > act_info.a());
126 }
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000127
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100128 if (nullptr != output)
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000129 {
130 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000131 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000132 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000133 }
134
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000135 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var);
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000136 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100137 if (beta != nullptr)
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000138 {
139 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, beta);
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000140 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, beta);
141 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100142 if (gamma != nullptr)
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000143 {
144 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, gamma);
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000145 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma);
146 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100147 ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(get_data_layout_dimension_index(
148 input->data_layout(), DataLayoutDimension::CHANNEL)) != mean->dimension(0));
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000149
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000150 return Status{};
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000151}
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000152} //namespace
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100153
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000154void NEBatchNormalizationLayerKernel::configure_non_fused()
155{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100156 switch (_input->info()->data_type())
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000157 {
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000158 case DataType::F16:
Pablo Marquez Tello8d4cdd42023-11-21 10:10:01 +0000159 _func = REGISTER_FP16_NEON(cpu::fp16_batch_normalization_nchw_non_fused);
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000160 break;
161 case DataType::F32:
Pablo Marquez Tello8d4cdd42023-11-21 10:10:01 +0000162 _func = REGISTER_FP32_NEON(cpu::fp32_batch_normalization_nchw_non_fused);
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000163 break;
164 default:
165 ARM_COMPUTE_ERROR("Element size not supported");
166 break;
167 }
168}
169
170void NEBatchNormalizationLayerKernel::configure_fused()
171{
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000172 // NCHW Fused Batched Normalization with activation functions : FP32
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100173 static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f32_nchw = {
174 {ActivationLayerInfo::ActivationFunction::RELU,
Pablo Marquez Tello8d4cdd42023-11-21 10:10:01 +0000175 REGISTER_FP32_NEON(cpu::fp32_batch_normalization_nchw_non_fused_relu)},
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100176 {ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
Pablo Marquez Tello8d4cdd42023-11-21 10:10:01 +0000177 REGISTER_FP32_NEON(cpu::fp32_batch_normalization_nchw_non_fused_brelu)},
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100178 {ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
Pablo Marquez Tello8d4cdd42023-11-21 10:10:01 +0000179 REGISTER_FP32_NEON(cpu::fp32_batch_normalization_nchw_non_fused_lubrelu)}};
180
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100181 // NCHW Fused Batched Normalization with activation functions : FP16
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100182 static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f16_nchw = {
183 {ActivationLayerInfo::ActivationFunction::RELU,
Pablo Marquez Tello8d4cdd42023-11-21 10:10:01 +0000184 REGISTER_FP16_NEON(cpu::fp16_batch_normalization_nchw_non_fused_relu)},
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100185 {ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
Pablo Marquez Tello8d4cdd42023-11-21 10:10:01 +0000186 REGISTER_FP16_NEON(cpu::fp16_batch_normalization_nchw_non_fused_brelu)},
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100187 {ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
Pablo Marquez Tello8d4cdd42023-11-21 10:10:01 +0000188 REGISTER_FP16_NEON(cpu::fp16_batch_normalization_nchw_non_fused_lubrelu)}};
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000189
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100190 switch (_input->info()->data_type())
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000191 {
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100192 case DataType::F16:
Sheri Zhang8d5d78b2020-12-15 20:25:31 +0000193 _func = bn_fused_map_f16_nchw[_act_info.activation()];
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100194 break;
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000195 case DataType::F32:
Sheri Zhang8d5d78b2020-12-15 20:25:31 +0000196 _func = bn_fused_map_f32_nchw[_act_info.activation()];
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000197 break;
198 default:
199 ARM_COMPUTE_ERROR("Element size not supported");
200 break;
201 }
202}
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000203
204NEBatchNormalizationLayerKernel::NEBatchNormalizationLayerKernel()
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100205 : _func(nullptr),
206 _input(nullptr),
207 _output(nullptr),
208 _mean(nullptr),
209 _var(nullptr),
210 _gamma(nullptr),
211 _beta(nullptr),
212 _epsilon(),
213 _act_info()
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000214{
215}
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100216
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100217void NEBatchNormalizationLayerKernel::configure(ITensor *input,
218 ITensor *output,
219 const ITensor *mean,
220 const ITensor *var,
221 const ITensor *beta,
222 const ITensor *gamma,
223 float epsilon,
224 ActivationLayerInfo act_info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100225{
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000226 ARM_COMPUTE_ERROR_ON_NULLPTR(input, mean, var);
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000227
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000228 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr,
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100229 mean->info(), var->info(), (beta != nullptr) ? beta->info() : nullptr,
230 (gamma != nullptr) ? gamma->info() : nullptr, epsilon, act_info));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100231
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000232 _input = input;
233 _output = input;
234 _mean = mean;
235 _var = var;
236 _gamma = gamma;
237 _beta = beta;
238 _epsilon = epsilon;
239 _act_info = act_info;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100240
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000241 const bool run_in_place = (output == nullptr) || (output == input);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100242 if (!run_in_place)
Georgios Pinitas409ee0a2017-08-18 10:16:09 +0100243 {
Georgios Pinitas409ee0a2017-08-18 10:16:09 +0100244 _output = output;
245 }
246
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000247 // Configure activation function to run
Sheri Zhang8d5d78b2020-12-15 20:25:31 +0000248 const bool is_nchw = _input->info()->data_layout() == DataLayout::NCHW;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100249 if (is_nchw)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100250 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100251 if (_act_info.enabled())
Sheri Zhang8d5d78b2020-12-15 20:25:31 +0000252 {
253 configure_fused();
254 }
255 else
256 {
257 configure_non_fused();
258 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100259 }
260
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000261 // Configure kernel window
Sheri Zhang8d5d78b2020-12-15 20:25:31 +0000262 Window win = calculate_max_window(*input->info(), Steps());
263 INEKernel::configure(win);
264
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100265 if (output != nullptr)
Sheri Zhang8d5d78b2020-12-15 20:25:31 +0000266 {
267 // Output auto initialization if not yet initialized
268 auto_init_if_empty(*output->info(), *input->info()->clone());
Sheri Zhang8d5d78b2020-12-15 20:25:31 +0000269 }
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000270}
271
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100272Status NEBatchNormalizationLayerKernel::validate(const ITensorInfo *input,
273 const ITensorInfo *output,
274 const ITensorInfo *mean,
275 const ITensorInfo *var,
276 const ITensorInfo *beta,
277 const ITensorInfo *gamma,
278 float epsilon,
279 ActivationLayerInfo act_info)
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000280{
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000281 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mean, var, beta, gamma, epsilon, act_info));
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000282
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000283 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100284}
285
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100286void NEBatchNormalizationLayerKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100287{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100288 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100289 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
290 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
Sheri Zhang8d5d78b2020-12-15 20:25:31 +0000291 ARM_COMPUTE_ERROR_ON(_func == nullptr && _input->info()->data_layout() == DataLayout::NCHW);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100292
Sheri Zhang8d5d78b2020-12-15 20:25:31 +0000293 const bool is_nchw = _input->info()->data_layout() == DataLayout::NCHW;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100294 if (is_nchw)
Sheri Zhang8d5d78b2020-12-15 20:25:31 +0000295 {
Pablo Marquez Tello8d4cdd42023-11-21 10:10:01 +0000296 (*_func)(window, _input, _output, _mean, _var, _beta, _gamma, _epsilon, _act_info);
Sheri Zhang8d5d78b2020-12-15 20:25:31 +0000297 }
298 else
299 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100300 const auto *uk =
301 get_implementation(BatchNormalizationSelectorData{_input->info()->data_type(), CPUInfo::get()});
Sheri Zhang8d5d78b2020-12-15 20:25:31 +0000302 uk->ukernel(_input, _output, _mean, _var, _beta, _gamma, _epsilon, _act_info, window);
303 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100304}
Georgios Pinitas980a9162020-06-03 20:16:46 +0100305} // namespace arm_compute