blob: 0651cf28e6dd4beabdadc67d973ca46b02631c39 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2017-2020 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEBatchNormalizationLayerKernel.h"
25
Anthony Barbiereaefd002018-07-20 17:49:35 +010026#include "arm_compute/core/CPP/Validate.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/NEON/NEFixedPoint.h"
29#include "arm_compute/core/NEON/NEMath.h"
Georgios Pinitas57c033b2018-02-15 12:29:44 +000030#include "arm_compute/core/NEON/kernels/detail/NEActivationFunctionDetail.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010031#include "arm_compute/core/TensorInfo.h"
32#include "arm_compute/core/Utils.h"
33#include "arm_compute/core/Validate.h"
34#include "arm_compute/core/Window.h"
35
Georgios Pinitas980a9162020-06-03 20:16:46 +010036#include "arm_compute/core/NEON/wrapper/wrapper.h"
37
Georgios Pinitas57c033b2018-02-15 12:29:44 +000038#include <map>
39
Georgios Pinitas980a9162020-06-03 20:16:46 +010040namespace arm_compute
41{
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000042namespace
Anthony Barbier6ff3b192017-09-04 18:44:23 +010043{
Georgios Pinitas57c033b2018-02-15 12:29:44 +000044Status
45validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *mean, const ITensorInfo *var,
46 const ITensorInfo *beta, const ITensorInfo *gamma, float epsilon, ActivationLayerInfo act_info)
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000047{
48 ARM_COMPUTE_UNUSED(epsilon);
Anthony Barbiereaefd002018-07-20 17:49:35 +010049 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
Georgios Pinitasaaba4c62018-08-22 16:20:21 +010050 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
Georgios Pinitas57c033b2018-02-15 12:29:44 +000051
52 if(act_info.enabled())
53 {
54 ActivationLayerInfo::ActivationFunction act = act_info.activation();
Georgios Pinitas6f109bd2018-07-16 12:57:42 +010055 ARM_COMPUTE_RETURN_ERROR_ON(act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::RELU
56 && act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU
Georgios Pinitas57c033b2018-02-15 12:29:44 +000057 && act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
58 ARM_COMPUTE_RETURN_ERROR_ON(act_info.b() > act_info.a());
59 }
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000060
61 if(nullptr != output)
62 {
63 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +000064 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000065 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000066 }
67
Michele Di Giorgio4d336302018-03-02 09:43:54 +000068 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var);
Michele Di Giorgio4d336302018-03-02 09:43:54 +000069 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var);
70 if(beta != nullptr)
71 {
72 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, beta);
Michele Di Giorgio4d336302018-03-02 09:43:54 +000073 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, beta);
74 }
75 if(gamma != nullptr)
76 {
77 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, gamma);
Michele Di Giorgio4d336302018-03-02 09:43:54 +000078 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma);
79 }
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +000080 ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)) != mean->dimension(0));
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000081
Georgios Pinitas631c41a2017-12-06 11:53:03 +000082 return Status{};
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000083}
84
Giorgio Arena47463262019-08-05 17:15:40 +010085std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, ITensorInfo *mean, ITensorInfo *var, ITensorInfo *gamma, ITensorInfo *beta)
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000086{
Georgios Pinitas980a9162020-06-03 20:16:46 +010087 ARM_COMPUTE_UNUSED(mean, var, gamma, beta);
88
89 // Configure kernel window
90 Window win = calculate_max_window(*input, Steps());
91
Michele Di Giorgio4d336302018-03-02 09:43:54 +000092 if(output != nullptr)
93 {
Georgios Pinitas980a9162020-06-03 20:16:46 +010094 // Output auto initialization if not yet initialized
Michele Di Giorgio4d336302018-03-02 09:43:54 +000095 auto_init_if_empty(*output, *input->clone());
Georgios Pinitas980a9162020-06-03 20:16:46 +010096
97 // NEBatchNormalizationLayerKernel doesn't need padding so update_window_and_padding() can be skipped
98 Coordinates coord;
99 coord.set_num_dimensions(output->num_dimensions());
100 output->set_valid_region(ValidRegion(coord, output->tensor_shape()));
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000101 }
102
Georgios Pinitas980a9162020-06-03 20:16:46 +0100103 return std::make_pair(Status{}, win);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100104}
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000105} //namespace
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100106
Georgios Pinitas980a9162020-06-03 20:16:46 +0100107template <typename T, bool fused_activation, typename F>
108void NEBatchNormalizationLayerKernel::batch_normalization_nchw(const Window &window)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100109{
Georgios Pinitas980a9162020-06-03 20:16:46 +0100110 /** NEON vector tag type. */
111 using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
112
113 const int window_step_x = 16 / sizeof(T);
114 const auto window_start_x = static_cast<int>(window.x().start());
115 const auto window_end_x = static_cast<int>(window.x().end());
116
117 Window win_to_use = window;
118 win_to_use.set(Window::DimX, Window::Dimension(0, 1, 1));
119
120 Iterator input(_input, win_to_use);
121 Iterator output(_output, win_to_use);
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100122
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100123 F activation_functor(_act_info);
124
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100125 // Hold information about the current feature map we are iterating.
126 // Only compute denominator and NEON vectors once per feature map.
127 int slice = -1;
128
Georgios Pinitas980a9162020-06-03 20:16:46 +0100129 const auto input_mean = reinterpret_cast<const T *>(_mean->ptr_to_element(Coordinates(0, 0)));
130 const auto input_var = reinterpret_cast<const T *>(_var->ptr_to_element(Coordinates(0, 0)));
131 const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const T *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
132 const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const T *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100133
Georgios Pinitas980a9162020-06-03 20:16:46 +0100134 T mean = static_cast<T>(0);
135 T var = static_cast<T>(0);
136 T gamma = static_cast<T>(1);
137 T beta = static_cast<T>(0);
138 T denominator = static_cast<T>(0);
139
140 auto mean_vec = wrapper::vdup_n(mean, ExactTagType{});
141 auto var_vec = wrapper::vdup_n(var, ExactTagType{});
142 auto gamma_vec = wrapper::vdup_n(gamma, ExactTagType{});
143 auto beta_vec = wrapper::vdup_n(beta, ExactTagType{});
144 auto denominator_vec = wrapper::vdup_n(denominator, ExactTagType{});
145 const auto epsilon_vec = wrapper::vdup_n(static_cast<T>(_epsilon), ExactTagType{});
146 execute_window_loop(win_to_use, [&](const Coordinates & id)
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100147 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100148 const auto input_ptr = reinterpret_cast<const T *>(input.ptr());
149 const auto output_ptr = reinterpret_cast<T *>(output.ptr());
150
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100151 if(slice != id.z())
152 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100153 mean = input_mean[id.z()];
154 var = input_var[id.z()];
155 mean_vec = wrapper::vdup_n(mean, ExactTagType{});
156 var_vec = wrapper::vdup_n(var, ExactTagType{});
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000157 if(input_gamma != nullptr)
158 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100159 gamma = input_gamma[id.z()];
160 gamma_vec = wrapper::vdup_n(gamma, ExactTagType{});
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000161 }
162 if(input_beta != nullptr)
163 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100164 beta = input_beta[id.z()];
165 beta_vec = wrapper::vdup_n(beta, ExactTagType{});
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000166 }
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100167
168 // Calculate denominator
Georgios Pinitas980a9162020-06-03 20:16:46 +0100169 denominator_vec = wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec));
170 denominator = wrapper::vgetlane(denominator_vec, 0);
171 slice = id.z();
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100172 }
173
Georgios Pinitas980a9162020-06-03 20:16:46 +0100174 // Perform core calculations using vector operations
175 int x = window_start_x;
176 for(; x <= (window_end_x - window_step_x); x += window_step_x)
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100177 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100178 // Calculate x bar
179 const auto numerator = wrapper::vsub(wrapper::vloadq(input_ptr + x), mean_vec);
180 const auto x_bar = wrapper::vmul(numerator, denominator_vec);
181 auto res = wrapper::vmla(beta_vec, x_bar, gamma_vec);
182
183 // Perform fused activation
184 if(fused_activation)
185 {
186 activation_functor(res);
187 }
188
189 // Store results
190 wrapper::vstore(output_ptr + x, res);
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100191 }
192
Georgios Pinitas980a9162020-06-03 20:16:46 +0100193 // Compute left-over elements
194 for(; x < window_end_x; ++x)
195 {
196 const T numerator = input_ptr[x] - mean;
197 const T x_bar = numerator * denominator;
198 T res = beta + x_bar * gamma;
199
200 // Perform fused activation
201 if(fused_activation)
202 {
203 activation_functor(res);
204 }
205
206 // Store results
207 *(output_ptr + x) = res;
208 }
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100209 },
210 input, output);
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000211}
212
Georgios Pinitas980a9162020-06-03 20:16:46 +0100213template <typename T, bool fused_activation, typename F>
214void NEBatchNormalizationLayerKernel::batch_normalization_nhwc(const Window &window)
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000215{
Georgios Pinitas980a9162020-06-03 20:16:46 +0100216 /** NEON vector tag type. */
217 using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
218
219 const int window_step_x = 16 / sizeof(T);
220 const auto window_start_x = static_cast<int>(window.x().start());
221 const auto window_end_x = static_cast<int>(window.x().end());
222
223 Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
224 win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
225
226 Iterator input(_input, win_collapsed);
227 Iterator output(_output, win_collapsed);
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000228
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100229 F activation_functor(_act_info);
230
Georgios Pinitas980a9162020-06-03 20:16:46 +0100231 const auto input_mean = reinterpret_cast<const T *>(_mean->ptr_to_element(Coordinates(0, 0)));
232 const auto input_var = reinterpret_cast<const T *>(_var->ptr_to_element(Coordinates(0, 0)));
233 const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const T *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
234 const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const T *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000235
Georgios Pinitas980a9162020-06-03 20:16:46 +0100236 const auto epsilon_vec = wrapper::vdup_n(static_cast<T>(_epsilon), ExactTagType{});
237 execute_window_loop(win_collapsed, [&](const Coordinates &)
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000238 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100239 const auto input_ptr = reinterpret_cast<const T *>(input.ptr());
240 const auto output_ptr = reinterpret_cast<T *>(output.ptr());
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000241
Georgios Pinitas980a9162020-06-03 20:16:46 +0100242 // Perform core calculations using vector operations
243 int x = window_start_x;
244 for(; x <= (window_end_x - window_step_x); x += window_step_x)
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000245 {
246 // Conctruct vectors
Georgios Pinitas980a9162020-06-03 20:16:46 +0100247 const auto mean_vec = wrapper::vloadq(input_mean + x);
248 const auto var_vec = wrapper::vloadq(input_var + x);
249 const auto gamma_vec = (input_gamma != nullptr) ? wrapper::vloadq(input_gamma + x) : wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
250 const auto beta_vec = (input_beta != nullptr) ? wrapper::vloadq(input_beta + x) : wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000251
252 // Calculate denominator
Georgios Pinitas980a9162020-06-03 20:16:46 +0100253 const auto denominator = wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec));
254
255 // Calculate x bar
256 const auto numerator = wrapper::vsub(wrapper::vloadq(input_ptr + x), mean_vec);
257 const auto x_bar = wrapper::vmul(numerator, denominator);
258 auto res = wrapper::vmla(beta_vec, x_bar, gamma_vec);
259
260 // Perform fused activation
261 if(fused_activation)
262 {
263 activation_functor(res);
264 }
265
266 // Store results
267 wrapper::vstore(output_ptr + x, res);
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000268 }
269
Georgios Pinitas980a9162020-06-03 20:16:46 +0100270 // Compute left-over elements
271 for(; x < window_end_x; ++x)
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000272 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100273 // Conctruct vectors
274 const T gamma = (input_gamma != nullptr) ? input_gamma[x] : 1.f;
275 const T beta = (input_beta != nullptr) ? input_beta[x] : 0.f;
276
277 const T denominator = sqrt(input_var[x] + _epsilon);
278 const T numerator = input_ptr[x] - input_mean[x];
279 const T x_bar = numerator / denominator;
280 T res = beta + x_bar * gamma;
281
282 // Perform fused activation
283 if(fused_activation)
284 {
285 activation_functor(res);
286 }
287
288 // Store results
289 *reinterpret_cast<T *>(output_ptr + x) = res;
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000290 }
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000291 },
292 input, output);
293}
294
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000295void NEBatchNormalizationLayerKernel::configure_non_fused()
296{
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000297 const bool is_nhwc = _input->info()->data_layout() == DataLayout::NHWC;
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000298 switch(_input->info()->data_type())
299 {
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100300#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000301 case DataType::F16:
Georgios Pinitas980a9162020-06-03 20:16:46 +0100302 _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float16_t, false, detail::dummy<float16_t, 8>> :
303 &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, false, detail::dummy<float16_t, 8>>;
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000304 break;
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100305#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000306 case DataType::F32:
Georgios Pinitas980a9162020-06-03 20:16:46 +0100307 _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float, false, detail::dummy<float, 4>> :
308 &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, false, detail::dummy<float, 4>>;
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000309 break;
310 default:
311 ARM_COMPUTE_ERROR("Element size not supported");
312 break;
313 }
314}
315
316void NEBatchNormalizationLayerKernel::configure_fused()
317{
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000318 // NCHW Fused Batched Normalization with activation functions : FP32
319 static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f32_nchw =
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000320 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100321 { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, true, detail::relu<float, 4>> },
322 { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, true, detail::brelu<float, 4>> },
323 { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, true, detail::lubrelu<float, 4>> }
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000324 };
325 // NHWC Fused Batched Normalization with activation functions : FP32
326 static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f32_nhwc =
327 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100328 { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float, true, detail::relu<float, 4>> },
329 { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float, true, detail::brelu<float, 4>> },
330 { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float, true, detail::lubrelu<float, 4>> }
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000331 };
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100332#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
333 // NCHW Fused Batched Normalization with activation functions : FP16
334 static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f16_nchw =
335 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100336 { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, true, detail::relu<float16_t, 8>> },
337 { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, true, detail::brelu<float16_t, 8>> },
338 { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, true, detail::lubrelu<float16_t, 8>> }
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100339 };
340 // NHWC Fused Batched Normalization with activation functions : FP16
341 static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f16_nhwc =
342 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100343 { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float16_t, true, detail::relu<float16_t, 8>> },
344 { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float16_t, true, detail::brelu<float16_t, 8>> },
345 { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float16_t, true, detail::lubrelu<float16_t, 8>> }
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100346 };
347#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000348
349 switch(_input->info()->data_type())
350 {
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100351#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
352 case DataType::F16:
353 _func = (_input->info()->data_layout() == DataLayout::NHWC) ? bn_fused_map_f16_nhwc[_act_info.activation()] : bn_fused_map_f16_nchw[_act_info.activation()];
354 break;
355#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000356 case DataType::F32:
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000357 _func = (_input->info()->data_layout() == DataLayout::NHWC) ? bn_fused_map_f32_nhwc[_act_info.activation()] : bn_fused_map_f32_nchw[_act_info.activation()];
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000358 break;
359 default:
360 ARM_COMPUTE_ERROR("Element size not supported");
361 break;
362 }
363}
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000364
365NEBatchNormalizationLayerKernel::NEBatchNormalizationLayerKernel()
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000366 : _func(nullptr), _input(nullptr), _output(nullptr), _mean(nullptr), _var(nullptr), _gamma(nullptr), _beta(nullptr), _epsilon(), _act_info()
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000367{
368}
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100369
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000370void NEBatchNormalizationLayerKernel::configure(ITensor *input, ITensor *output,
371 const ITensor *mean, const ITensor *var,
372 const ITensor *beta, const ITensor *gamma,
373 float epsilon, ActivationLayerInfo act_info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100374{
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000375 ARM_COMPUTE_ERROR_ON_NULLPTR(input, mean, var);
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000376
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000377 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr,
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000378 mean->info(), var->info(),
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000379 (beta != nullptr) ? beta->info() : nullptr,
380 (gamma != nullptr) ? gamma->info() : nullptr,
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000381 epsilon, act_info));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100382
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000383 _input = input;
384 _output = input;
385 _mean = mean;
386 _var = var;
387 _gamma = gamma;
388 _beta = beta;
389 _epsilon = epsilon;
390 _act_info = act_info;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100391
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000392 const bool run_in_place = (output == nullptr) || (output == input);
393 if(!run_in_place)
Georgios Pinitas409ee0a2017-08-18 10:16:09 +0100394 {
Georgios Pinitas409ee0a2017-08-18 10:16:09 +0100395 _output = output;
396 }
397
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000398 // Configure activation function to run
399 if(_act_info.enabled())
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100400 {
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000401 configure_fused();
402 }
403 else
404 {
405 configure_non_fused();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100406 }
407
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000408 // Configure kernel window
Giorgio Arena47463262019-08-05 17:15:40 +0100409 auto win_config = validate_and_configure_window(input->info(), (run_in_place) ? nullptr : output->info(), mean->info(), var->info(), (gamma != nullptr) ? gamma->info() : nullptr,
410 (beta != nullptr) ? beta->info() : nullptr);
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000411 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
412 INEKernel::configure(win_config.second);
413}
414
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000415Status NEBatchNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output,
416 const ITensorInfo *mean, const ITensorInfo *var,
417 const ITensorInfo *beta, const ITensorInfo *gamma,
418 float epsilon, ActivationLayerInfo act_info)
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000419{
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000420 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mean, var, beta, gamma, epsilon, act_info));
Giorgio Arena47463262019-08-05 17:15:40 +0100421 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output ? output->clone().get() : nullptr, mean->clone().get(), var->clone().get(),
422 (gamma != nullptr) ? gamma->clone().get() : nullptr, (beta != nullptr) ? beta->clone().get() : nullptr)
423 .first);
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000424
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000425 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100426}
427
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100428void NEBatchNormalizationLayerKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100429{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100430 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100431 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
432 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
433 ARM_COMPUTE_ERROR_ON(_func == nullptr);
434
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000435 (this->*_func)(window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100436}
Georgios Pinitas980a9162020-06-03 20:16:46 +0100437} // namespace arm_compute