blob: afb08e5d1c3bf781d1c78a3cb92b1079c4f8cc35 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2017-2020 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrouebcebf12020-10-21 00:04:14 +010024#include "src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010025
26#include "arm_compute/core/Helpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/TensorInfo.h"
28#include "arm_compute/core/Utils.h"
29#include "arm_compute/core/Validate.h"
30#include "arm_compute/core/Window.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010031#include "src/core/CPP/Validate.h"
Georgios Pinitasddb93bb2020-10-02 16:38:59 +010032#include "src/core/NEON/NEFixedPoint.h"
33#include "src/core/NEON/NEMath.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010034#include "src/core/helpers/AutoConfiguration.h"
35#include "src/core/helpers/WindowHelpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010036
Georgios Pinitasddb93bb2020-10-02 16:38:59 +010037#include "src/core/NEON/kernels/detail/NEActivationFunctionDetail.h"
38#include "src/core/NEON/wrapper/wrapper.h"
Georgios Pinitas980a9162020-06-03 20:16:46 +010039
Georgios Pinitas57c033b2018-02-15 12:29:44 +000040#include <map>
41
Georgios Pinitas980a9162020-06-03 20:16:46 +010042namespace arm_compute
43{
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000044namespace
Anthony Barbier6ff3b192017-09-04 18:44:23 +010045{
Georgios Pinitas57c033b2018-02-15 12:29:44 +000046Status
47validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *mean, const ITensorInfo *var,
48 const ITensorInfo *beta, const ITensorInfo *gamma, float epsilon, ActivationLayerInfo act_info)
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000049{
50 ARM_COMPUTE_UNUSED(epsilon);
Anthony Barbiereaefd002018-07-20 17:49:35 +010051 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
Georgios Pinitasaaba4c62018-08-22 16:20:21 +010052 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
Georgios Pinitas57c033b2018-02-15 12:29:44 +000053
54 if(act_info.enabled())
55 {
56 ActivationLayerInfo::ActivationFunction act = act_info.activation();
Georgios Pinitas6f109bd2018-07-16 12:57:42 +010057 ARM_COMPUTE_RETURN_ERROR_ON(act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::RELU
58 && act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU
Georgios Pinitas57c033b2018-02-15 12:29:44 +000059 && act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
60 ARM_COMPUTE_RETURN_ERROR_ON(act_info.b() > act_info.a());
61 }
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000062
63 if(nullptr != output)
64 {
65 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +000066 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000067 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000068 }
69
Michele Di Giorgio4d336302018-03-02 09:43:54 +000070 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var);
Michele Di Giorgio4d336302018-03-02 09:43:54 +000071 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var);
72 if(beta != nullptr)
73 {
74 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, beta);
Michele Di Giorgio4d336302018-03-02 09:43:54 +000075 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, beta);
76 }
77 if(gamma != nullptr)
78 {
79 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, gamma);
Michele Di Giorgio4d336302018-03-02 09:43:54 +000080 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma);
81 }
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +000082 ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)) != mean->dimension(0));
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000083
Georgios Pinitas631c41a2017-12-06 11:53:03 +000084 return Status{};
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000085}
86
Giorgio Arena47463262019-08-05 17:15:40 +010087std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, ITensorInfo *mean, ITensorInfo *var, ITensorInfo *gamma, ITensorInfo *beta)
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000088{
Georgios Pinitas980a9162020-06-03 20:16:46 +010089 ARM_COMPUTE_UNUSED(mean, var, gamma, beta);
90
91 // Configure kernel window
92 Window win = calculate_max_window(*input, Steps());
93
Michele Di Giorgio4d336302018-03-02 09:43:54 +000094 if(output != nullptr)
95 {
Georgios Pinitas980a9162020-06-03 20:16:46 +010096 // Output auto initialization if not yet initialized
Michele Di Giorgio4d336302018-03-02 09:43:54 +000097 auto_init_if_empty(*output, *input->clone());
Georgios Pinitas980a9162020-06-03 20:16:46 +010098
99 // NEBatchNormalizationLayerKernel doesn't need padding so update_window_and_padding() can be skipped
100 Coordinates coord;
101 coord.set_num_dimensions(output->num_dimensions());
102 output->set_valid_region(ValidRegion(coord, output->tensor_shape()));
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000103 }
104
Georgios Pinitas980a9162020-06-03 20:16:46 +0100105 return std::make_pair(Status{}, win);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100106}
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000107} //namespace
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100108
Georgios Pinitas980a9162020-06-03 20:16:46 +0100109template <typename T, bool fused_activation, typename F>
110void NEBatchNormalizationLayerKernel::batch_normalization_nchw(const Window &window)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100111{
Georgios Pinitas980a9162020-06-03 20:16:46 +0100112 /** NEON vector tag type. */
113 using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
114
115 const int window_step_x = 16 / sizeof(T);
116 const auto window_start_x = static_cast<int>(window.x().start());
117 const auto window_end_x = static_cast<int>(window.x().end());
118
119 Window win_to_use = window;
120 win_to_use.set(Window::DimX, Window::Dimension(0, 1, 1));
121
122 Iterator input(_input, win_to_use);
123 Iterator output(_output, win_to_use);
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100124
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100125 F activation_functor(_act_info);
126
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100127 // Hold information about the current feature map we are iterating.
128 // Only compute denominator and NEON vectors once per feature map.
129 int slice = -1;
130
Georgios Pinitas980a9162020-06-03 20:16:46 +0100131 const auto input_mean = reinterpret_cast<const T *>(_mean->ptr_to_element(Coordinates(0, 0)));
132 const auto input_var = reinterpret_cast<const T *>(_var->ptr_to_element(Coordinates(0, 0)));
133 const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const T *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
134 const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const T *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100135
Georgios Pinitas980a9162020-06-03 20:16:46 +0100136 T mean = static_cast<T>(0);
137 T var = static_cast<T>(0);
138 T gamma = static_cast<T>(1);
139 T beta = static_cast<T>(0);
140 T denominator = static_cast<T>(0);
141
142 auto mean_vec = wrapper::vdup_n(mean, ExactTagType{});
143 auto var_vec = wrapper::vdup_n(var, ExactTagType{});
144 auto gamma_vec = wrapper::vdup_n(gamma, ExactTagType{});
145 auto beta_vec = wrapper::vdup_n(beta, ExactTagType{});
146 auto denominator_vec = wrapper::vdup_n(denominator, ExactTagType{});
147 const auto epsilon_vec = wrapper::vdup_n(static_cast<T>(_epsilon), ExactTagType{});
148 execute_window_loop(win_to_use, [&](const Coordinates & id)
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100149 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100150 const auto input_ptr = reinterpret_cast<const T *>(input.ptr());
151 const auto output_ptr = reinterpret_cast<T *>(output.ptr());
152
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100153 if(slice != id.z())
154 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100155 mean = input_mean[id.z()];
156 var = input_var[id.z()];
157 mean_vec = wrapper::vdup_n(mean, ExactTagType{});
158 var_vec = wrapper::vdup_n(var, ExactTagType{});
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000159 if(input_gamma != nullptr)
160 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100161 gamma = input_gamma[id.z()];
162 gamma_vec = wrapper::vdup_n(gamma, ExactTagType{});
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000163 }
164 if(input_beta != nullptr)
165 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100166 beta = input_beta[id.z()];
167 beta_vec = wrapper::vdup_n(beta, ExactTagType{});
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000168 }
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100169
170 // Calculate denominator
Georgios Pinitas980a9162020-06-03 20:16:46 +0100171 denominator_vec = wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec));
172 denominator = wrapper::vgetlane(denominator_vec, 0);
173 slice = id.z();
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100174 }
175
Georgios Pinitas980a9162020-06-03 20:16:46 +0100176 // Perform core calculations using vector operations
177 int x = window_start_x;
178 for(; x <= (window_end_x - window_step_x); x += window_step_x)
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100179 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100180 // Calculate x bar
181 const auto numerator = wrapper::vsub(wrapper::vloadq(input_ptr + x), mean_vec);
182 const auto x_bar = wrapper::vmul(numerator, denominator_vec);
183 auto res = wrapper::vmla(beta_vec, x_bar, gamma_vec);
184
185 // Perform fused activation
186 if(fused_activation)
187 {
188 activation_functor(res);
189 }
190
191 // Store results
192 wrapper::vstore(output_ptr + x, res);
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100193 }
194
Georgios Pinitas980a9162020-06-03 20:16:46 +0100195 // Compute left-over elements
196 for(; x < window_end_x; ++x)
197 {
198 const T numerator = input_ptr[x] - mean;
199 const T x_bar = numerator * denominator;
200 T res = beta + x_bar * gamma;
201
202 // Perform fused activation
203 if(fused_activation)
204 {
205 activation_functor(res);
206 }
207
208 // Store results
209 *(output_ptr + x) = res;
210 }
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100211 },
212 input, output);
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000213}
214
Georgios Pinitas980a9162020-06-03 20:16:46 +0100215template <typename T, bool fused_activation, typename F>
216void NEBatchNormalizationLayerKernel::batch_normalization_nhwc(const Window &window)
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000217{
Georgios Pinitas980a9162020-06-03 20:16:46 +0100218 /** NEON vector tag type. */
219 using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
220
221 const int window_step_x = 16 / sizeof(T);
222 const auto window_start_x = static_cast<int>(window.x().start());
223 const auto window_end_x = static_cast<int>(window.x().end());
224
225 Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
226 win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
227
228 Iterator input(_input, win_collapsed);
229 Iterator output(_output, win_collapsed);
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000230
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100231 F activation_functor(_act_info);
232
Georgios Pinitas980a9162020-06-03 20:16:46 +0100233 const auto input_mean = reinterpret_cast<const T *>(_mean->ptr_to_element(Coordinates(0, 0)));
234 const auto input_var = reinterpret_cast<const T *>(_var->ptr_to_element(Coordinates(0, 0)));
235 const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const T *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
236 const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const T *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000237
Georgios Pinitas980a9162020-06-03 20:16:46 +0100238 const auto epsilon_vec = wrapper::vdup_n(static_cast<T>(_epsilon), ExactTagType{});
239 execute_window_loop(win_collapsed, [&](const Coordinates &)
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000240 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100241 const auto input_ptr = reinterpret_cast<const T *>(input.ptr());
242 const auto output_ptr = reinterpret_cast<T *>(output.ptr());
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000243
Georgios Pinitas980a9162020-06-03 20:16:46 +0100244 // Perform core calculations using vector operations
245 int x = window_start_x;
246 for(; x <= (window_end_x - window_step_x); x += window_step_x)
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000247 {
248 // Conctruct vectors
Georgios Pinitas980a9162020-06-03 20:16:46 +0100249 const auto mean_vec = wrapper::vloadq(input_mean + x);
250 const auto var_vec = wrapper::vloadq(input_var + x);
251 const auto gamma_vec = (input_gamma != nullptr) ? wrapper::vloadq(input_gamma + x) : wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
252 const auto beta_vec = (input_beta != nullptr) ? wrapper::vloadq(input_beta + x) : wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000253
254 // Calculate denominator
Georgios Pinitas980a9162020-06-03 20:16:46 +0100255 const auto denominator = wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec));
256
257 // Calculate x bar
258 const auto numerator = wrapper::vsub(wrapper::vloadq(input_ptr + x), mean_vec);
259 const auto x_bar = wrapper::vmul(numerator, denominator);
260 auto res = wrapper::vmla(beta_vec, x_bar, gamma_vec);
261
262 // Perform fused activation
263 if(fused_activation)
264 {
265 activation_functor(res);
266 }
267
268 // Store results
269 wrapper::vstore(output_ptr + x, res);
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000270 }
271
Georgios Pinitas980a9162020-06-03 20:16:46 +0100272 // Compute left-over elements
273 for(; x < window_end_x; ++x)
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000274 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100275 // Conctruct vectors
276 const T gamma = (input_gamma != nullptr) ? input_gamma[x] : 1.f;
277 const T beta = (input_beta != nullptr) ? input_beta[x] : 0.f;
278
279 const T denominator = sqrt(input_var[x] + _epsilon);
280 const T numerator = input_ptr[x] - input_mean[x];
281 const T x_bar = numerator / denominator;
282 T res = beta + x_bar * gamma;
283
284 // Perform fused activation
285 if(fused_activation)
286 {
287 activation_functor(res);
288 }
289
290 // Store results
291 *reinterpret_cast<T *>(output_ptr + x) = res;
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000292 }
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000293 },
294 input, output);
295}
296
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000297void NEBatchNormalizationLayerKernel::configure_non_fused()
298{
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000299 const bool is_nhwc = _input->info()->data_layout() == DataLayout::NHWC;
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000300 switch(_input->info()->data_type())
301 {
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100302#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000303 case DataType::F16:
Georgios Pinitas980a9162020-06-03 20:16:46 +0100304 _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float16_t, false, detail::dummy<float16_t, 8>> :
305 &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, false, detail::dummy<float16_t, 8>>;
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000306 break;
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100307#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000308 case DataType::F32:
Georgios Pinitas980a9162020-06-03 20:16:46 +0100309 _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float, false, detail::dummy<float, 4>> :
310 &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, false, detail::dummy<float, 4>>;
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000311 break;
312 default:
313 ARM_COMPUTE_ERROR("Element size not supported");
314 break;
315 }
316}
317
318void NEBatchNormalizationLayerKernel::configure_fused()
319{
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000320 // NCHW Fused Batched Normalization with activation functions : FP32
321 static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f32_nchw =
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000322 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100323 { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, true, detail::relu<float, 4>> },
324 { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, true, detail::brelu<float, 4>> },
325 { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, true, detail::lubrelu<float, 4>> }
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000326 };
327 // NHWC Fused Batched Normalization with activation functions : FP32
328 static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f32_nhwc =
329 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100330 { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float, true, detail::relu<float, 4>> },
331 { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float, true, detail::brelu<float, 4>> },
332 { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float, true, detail::lubrelu<float, 4>> }
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000333 };
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100334#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
335 // NCHW Fused Batched Normalization with activation functions : FP16
336 static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f16_nchw =
337 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100338 { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, true, detail::relu<float16_t, 8>> },
339 { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, true, detail::brelu<float16_t, 8>> },
340 { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, true, detail::lubrelu<float16_t, 8>> }
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100341 };
342 // NHWC Fused Batched Normalization with activation functions : FP16
343 static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f16_nhwc =
344 {
Georgios Pinitas980a9162020-06-03 20:16:46 +0100345 { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float16_t, true, detail::relu<float16_t, 8>> },
346 { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float16_t, true, detail::brelu<float16_t, 8>> },
347 { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_nhwc<float16_t, true, detail::lubrelu<float16_t, 8>> }
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100348 };
349#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000350
351 switch(_input->info()->data_type())
352 {
Georgios Pinitasaaba4c62018-08-22 16:20:21 +0100353#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
354 case DataType::F16:
355 _func = (_input->info()->data_layout() == DataLayout::NHWC) ? bn_fused_map_f16_nhwc[_act_info.activation()] : bn_fused_map_f16_nchw[_act_info.activation()];
356 break;
357#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000358 case DataType::F32:
Michele Di Giorgio0cbb9272018-03-01 16:56:48 +0000359 _func = (_input->info()->data_layout() == DataLayout::NHWC) ? bn_fused_map_f32_nhwc[_act_info.activation()] : bn_fused_map_f32_nchw[_act_info.activation()];
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000360 break;
361 default:
362 ARM_COMPUTE_ERROR("Element size not supported");
363 break;
364 }
365}
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000366
367NEBatchNormalizationLayerKernel::NEBatchNormalizationLayerKernel()
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000368 : _func(nullptr), _input(nullptr), _output(nullptr), _mean(nullptr), _var(nullptr), _gamma(nullptr), _beta(nullptr), _epsilon(), _act_info()
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000369{
370}
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100371
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000372void NEBatchNormalizationLayerKernel::configure(ITensor *input, ITensor *output,
373 const ITensor *mean, const ITensor *var,
374 const ITensor *beta, const ITensor *gamma,
375 float epsilon, ActivationLayerInfo act_info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100376{
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000377 ARM_COMPUTE_ERROR_ON_NULLPTR(input, mean, var);
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000378
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000379 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr,
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000380 mean->info(), var->info(),
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000381 (beta != nullptr) ? beta->info() : nullptr,
382 (gamma != nullptr) ? gamma->info() : nullptr,
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000383 epsilon, act_info));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100384
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000385 _input = input;
386 _output = input;
387 _mean = mean;
388 _var = var;
389 _gamma = gamma;
390 _beta = beta;
391 _epsilon = epsilon;
392 _act_info = act_info;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100393
Michele Di Giorgio4d336302018-03-02 09:43:54 +0000394 const bool run_in_place = (output == nullptr) || (output == input);
395 if(!run_in_place)
Georgios Pinitas409ee0a2017-08-18 10:16:09 +0100396 {
Georgios Pinitas409ee0a2017-08-18 10:16:09 +0100397 _output = output;
398 }
399
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000400 // Configure activation function to run
401 if(_act_info.enabled())
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100402 {
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000403 configure_fused();
404 }
405 else
406 {
407 configure_non_fused();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100408 }
409
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000410 // Configure kernel window
Giorgio Arena47463262019-08-05 17:15:40 +0100411 auto win_config = validate_and_configure_window(input->info(), (run_in_place) ? nullptr : output->info(), mean->info(), var->info(), (gamma != nullptr) ? gamma->info() : nullptr,
412 (beta != nullptr) ? beta->info() : nullptr);
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000413 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
414 INEKernel::configure(win_config.second);
415}
416
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000417Status NEBatchNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output,
418 const ITensorInfo *mean, const ITensorInfo *var,
419 const ITensorInfo *beta, const ITensorInfo *gamma,
420 float epsilon, ActivationLayerInfo act_info)
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000421{
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000422 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mean, var, beta, gamma, epsilon, act_info));
Giorgio Arena47463262019-08-05 17:15:40 +0100423 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output ? output->clone().get() : nullptr, mean->clone().get(), var->clone().get(),
424 (gamma != nullptr) ? gamma->clone().get() : nullptr, (beta != nullptr) ? beta->clone().get() : nullptr)
425 .first);
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000426
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000427 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100428}
429
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100430void NEBatchNormalizationLayerKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100431{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100432 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100433 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
434 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
435 ARM_COMPUTE_ERROR_ON(_func == nullptr);
436
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000437 (this->*_func)(window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100438}
Georgios Pinitas980a9162020-06-03 20:16:46 +0100439} // namespace arm_compute