blob: 1f730a2c3c901c3a93f69f2d2a76afd8cdbc061d [file] [log] [blame]
/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/core/NEON/kernels/NEBatchNormalizationLayerKernel.h"
25
26#include "arm_compute/core/Helpers.h"
27#include "arm_compute/core/NEON/NEFixedPoint.h"
28#include "arm_compute/core/NEON/NEMath.h"
Georgios Pinitas57c033b2018-02-15 12:29:44 +000029#include "arm_compute/core/NEON/kernels/detail/NEActivationFunctionDetail.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/TensorInfo.h"
31#include "arm_compute/core/Utils.h"
32#include "arm_compute/core/Validate.h"
33#include "arm_compute/core/Window.h"
34
Georgios Pinitas57c033b2018-02-15 12:29:44 +000035#include <map>
36
Anthony Barbier6ff3b192017-09-04 18:44:23 +010037using namespace arm_compute;
38
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000039namespace
Anthony Barbier6ff3b192017-09-04 18:44:23 +010040{
Georgios Pinitas57c033b2018-02-15 12:29:44 +000041Status
42validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *mean, const ITensorInfo *var,
43 const ITensorInfo *beta, const ITensorInfo *gamma, float epsilon, ActivationLayerInfo act_info)
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000044{
45 ARM_COMPUTE_UNUSED(epsilon);
Georgios Pinitas57c033b2018-02-15 12:29:44 +000046 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16,
47 DataType::F32);
48
49 if(act_info.enabled())
50 {
51 ActivationLayerInfo::ActivationFunction act = act_info.activation();
52 ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32);
53 ARM_COMPUTE_RETURN_ERROR_ON(act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::RELU && act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU
54 && act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
55 ARM_COMPUTE_RETURN_ERROR_ON(act_info.b() > act_info.a());
56 }
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000057
58 if(nullptr != output)
59 {
60 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
61 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
62 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
63 }
64
65 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var, beta, gamma);
66 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var, beta, gamma);
67 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var, beta, gamma);
68 ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != mean->dimension(0));
69
Georgios Pinitas631c41a2017-12-06 11:53:03 +000070 return Status{};
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000071}
72
Georgios Pinitas631c41a2017-12-06 11:53:03 +000073std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output)
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000074{
75 unsigned int num_elems_processed_per_iteration = 16 / input->element_size();
76
77 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
78 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
79 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
80 bool window_changed = update_window_and_padding(win, input_access, output_access);
81 output_access.set_valid_region(win, input->valid_region());
Georgios Pinitas631c41a2017-12-06 11:53:03 +000082 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +000083 return std::make_pair(err, win);
Anthony Barbier6ff3b192017-09-04 18:44:23 +010084}
Georgios Pinitas57c033b2018-02-15 12:29:44 +000085} //namespace
Anthony Barbier6ff3b192017-09-04 18:44:23 +010086
Georgios Pinitas57c033b2018-02-15 12:29:44 +000087template <bool fused_activation>
88void NEBatchNormalizationLayerKernel::batch_normalization_qs8(const Window &window)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010089{
Georgios Pinitas57c033b2018-02-15 12:29:44 +000090 static_assert(!fused_activation, "Activation is not supported for QS8");
91
92 Iterator input(_input, window);
93 Iterator output(_output, window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +010094
95 // Hold information about the current feature map we are iterating.
96 // Only compute denominator and NEON vectors once per feature map.
97 int slice = -1;
98
Georgios Pinitas57c033b2018-02-15 12:29:44 +000099 const int fixed_point_position = _input->info()->fixed_point_position();
100 const auto input_mean = reinterpret_cast<const qint8_t *>(_mean->ptr_to_element(Coordinates(0, 0)));
101 const auto input_var = reinterpret_cast<const qint8_t *>(_var->ptr_to_element(Coordinates(0, 0)));
102 const auto input_gamma = reinterpret_cast<const qint8_t *>(_gamma->ptr_to_element(Coordinates(0, 0)));
103 const auto input_beta = reinterpret_cast<const qint8_t *>(_beta->ptr_to_element(Coordinates(0, 0)));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100104
105 qint8x16_t mean_vec = vdupq_n_qs8(0);
106 qint8x16_t var_vec = vdupq_n_qs8(0);
107 qint8x16_t gamma_vec = vdupq_n_qs8(0);
108 qint8x16_t beta_vec = vdupq_n_qs8(0);
109 qint8x16_t denominator = vdupq_n_qs8(0);
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000110 const qint8x16_t epsilon_vec = vdupq_n_qs8(sqcvt_qs8_f32(_epsilon, fixed_point_position));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100111 execute_window_loop(window, [&](const Coordinates & id)
112 {
113 if(slice != id.z())
114 {
115 // Conctruct vectors
116 mean_vec = vdupq_n_qs8(*(input_mean + id.z()));
117 var_vec = vdupq_n_qs8(*(input_var + id.z()));
118 gamma_vec = vdupq_n_qs8(*(input_gamma + id.z()));
119 beta_vec = vdupq_n_qs8(*(input_beta + id.z()));
120
121 // Calculate denominator
122 denominator = vqinvsqrtq_qs8(vqaddq_qs8(var_vec, epsilon_vec), fixed_point_position);
123 slice = id.z();
124 }
125
126 // Calculate x bar and store results
127 const qint8x16_t numerator = vqsubq_qs8(vld1q_qs8(reinterpret_cast<const qint8_t *>(input.ptr())), mean_vec);
128 const qint8x16_t x_bar = vqmulq_qs8(numerator, denominator, fixed_point_position);
129 vst1q_qs8(reinterpret_cast<qint8_t *>(output.ptr()), vqmlaq_qs8(beta_vec, x_bar, gamma_vec, fixed_point_position));
130 },
131 input, output);
132}
133
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000134template <bool fused_activation>
135void NEBatchNormalizationLayerKernel::batch_normalization_qs16(const Window &window)
Michalis Spyroubbd3d602017-06-21 17:29:40 +0100136{
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000137 static_assert(!fused_activation, "Activation is not supported for QS16");
138
139 Iterator input(_input, window);
140 Iterator output(_output, window);
Michalis Spyroubbd3d602017-06-21 17:29:40 +0100141
142 // Hold information about the current feature map we are iterating.
143 // Only compute denominator and NEON vectors once per feature map.
144 int slice = -1;
145
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000146 const int fixed_point_position = _input->info()->fixed_point_position();
147 const auto input_mean = reinterpret_cast<const qint16_t *>(_mean->ptr_to_element(Coordinates(0, 0)));
148 const auto input_var = reinterpret_cast<const qint16_t *>(_var->ptr_to_element(Coordinates(0, 0)));
149 const auto input_gamma = reinterpret_cast<const qint16_t *>(_gamma->ptr_to_element(Coordinates(0, 0)));
150 const auto input_beta = reinterpret_cast<const qint16_t *>(_beta->ptr_to_element(Coordinates(0, 0)));
Michalis Spyroubbd3d602017-06-21 17:29:40 +0100151
152 qint16x8_t mean_vec = vdupq_n_qs16(0);
153 qint16x8_t var_vec = vdupq_n_qs16(0);
154 qint16x8_t gamma_vec = vdupq_n_qs16(0);
155 qint16x8_t beta_vec = vdupq_n_qs16(0);
156 qint16x8_t denominator = vdupq_n_qs16(0);
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000157 const qint16x8_t epsilon_vec = vdupq_n_qs16(sqcvt_qs16_f32(_epsilon, fixed_point_position));
Michalis Spyroubbd3d602017-06-21 17:29:40 +0100158 execute_window_loop(window, [&](const Coordinates & id)
159 {
160 if(slice != id.z())
161 {
162 // Conctruct vectors
163 mean_vec = vdupq_n_qs16(*(input_mean + id.z()));
164 var_vec = vdupq_n_qs16(*(input_var + id.z()));
165 gamma_vec = vdupq_n_qs16(*(input_gamma + id.z()));
166 beta_vec = vdupq_n_qs16(*(input_beta + id.z()));
167
168 // Calculate denominator
169 denominator = vqinvsqrtq_qs16(vqaddq_qs16(var_vec, epsilon_vec), fixed_point_position);
170 slice = id.z();
171 }
172
173 // Calculate x bar and store results
174 const qint16x8_t numerator = vqsubq_qs16(vld1q_qs16(reinterpret_cast<const qint16_t *>(input.ptr())), mean_vec);
175 const qint16x8_t x_bar = vqmulq_qs16(numerator, denominator, fixed_point_position);
176 vst1q_qs16(reinterpret_cast<qint16_t *>(output.ptr()), vqmlaq_qs16(beta_vec, x_bar, gamma_vec, fixed_point_position));
177 },
178 input, output);
179}
180
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000181template <bool fused_activation>
182void NEBatchNormalizationLayerKernel::batch_normalization_fp16(const Window &window)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100183{
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000184 static_assert(!fused_activation, "Activation is not supported for QS8");
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100185
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000186 ARM_COMPUTE_UNUSED(window);
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000187#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000188 Iterator input(_input, window);
189 Iterator output(_output, window);
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100190
191 // Hold information about the current feature map we are iterating.
192 // Only compute denominator and NEON vectors once per feature map.
193 int slice = -1;
194
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000195 const auto input_mean = reinterpret_cast<const float16_t *>(_mean->ptr_to_element(Coordinates(0, 0)));
196 const auto input_var = reinterpret_cast<const float16_t *>(_var->ptr_to_element(Coordinates(0, 0)));
197 const auto input_gamma = reinterpret_cast<const float16_t *>(_gamma->ptr_to_element(Coordinates(0, 0)));
198 const auto input_beta = reinterpret_cast<const float16_t *>(_beta->ptr_to_element(Coordinates(0, 0)));
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100199
200 float16x8_t mean_vec = vdupq_n_f16(0.0);
201 float16x8_t var_vec = vdupq_n_f16(0.0);
202 float16x8_t gamma_vec = vdupq_n_f16(0.0);
203 float16x8_t beta_vec = vdupq_n_f16(0.0);
204 float16x8_t denominator = vdupq_n_f16(0.0);
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000205 const float16x8_t epsilon_vec = vdupq_n_f16(_epsilon);
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100206 execute_window_loop(window, [&](const Coordinates & id)
207 {
208 if(slice != id.z())
209 {
210 // Conctruct vectors
211 mean_vec = vdupq_n_f16(*(input_mean + id.z()));
212 var_vec = vdupq_n_f16(*(input_var + id.z()));
213 gamma_vec = vdupq_n_f16(*(input_gamma + id.z()));
214 beta_vec = vdupq_n_f16(*(input_beta + id.z()));
215
216 // Calculate denominator
217 denominator = vinvsqrtq_f16(vaddq_f16(var_vec, epsilon_vec));
218 slice = id.z();
219 }
220
221 // Calculate x bar and store results
222 const float16x8_t numerator = vsubq_f16(vld1q_f16(reinterpret_cast<const float16_t *>(input.ptr())), mean_vec);
223 const float16x8_t x_bar = vmulq_f16(numerator, denominator);
224 vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec)));
225 },
226 input, output);
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000227#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000228}
229
230template <bool fused_activation, typename F>
231void NEBatchNormalizationLayerKernel::batch_normalization_fp32(const Window &window)
232{
233 Iterator input(_input, window);
234 Iterator output(_output, window);
235
236 F activation_functor(_act_info);
237
238 // Hold information about the current feature map we are iterating.
239 // Only compute denominator and NEON vectors once per feature map.
240 int slice = -1;
241
242 const auto input_mean = reinterpret_cast<const float *>(_mean->ptr_to_element(Coordinates(0, 0)));
243 const auto input_var = reinterpret_cast<const float *>(_var->ptr_to_element(Coordinates(0, 0)));
244 const auto input_gamma = reinterpret_cast<const float *>(_gamma->ptr_to_element(Coordinates(0, 0)));
245 const auto input_beta = reinterpret_cast<const float *>(_beta->ptr_to_element(Coordinates(0, 0)));
246
247 float32x4_t mean_vec = vdupq_n_f32(0.0);
248 float32x4_t var_vec = vdupq_n_f32(0.0);
249 float32x4_t gamma_vec = vdupq_n_f32(0.0);
250 float32x4_t beta_vec = vdupq_n_f32(0.0);
251 float32x4_t denominator = vdupq_n_f32(0.0);
252 const float32x4_t epsilon_vec = vdupq_n_f32(_epsilon);
253 execute_window_loop(window, [&](const Coordinates & id)
254 {
255 if(slice != id.z())
256 {
257 // Conctruct vectors
258 mean_vec = vdupq_n_f32(*(input_mean + id.z()));
259 var_vec = vdupq_n_f32(*(input_var + id.z()));
260 gamma_vec = vdupq_n_f32(*(input_gamma + id.z()));
261 beta_vec = vdupq_n_f32(*(input_beta + id.z()));
262
263 // Calculate denominator
264 denominator = vinvsqrtq_f32(vaddq_f32(var_vec, epsilon_vec));
265 slice = id.z();
266 }
267
268 // Calculate x bar
269 const float32x4_t numerator = vsubq_f32(vld1q_f32(reinterpret_cast<const float *>(input.ptr())), mean_vec);
270 const float32x4_t x_bar = vmulq_f32(numerator, denominator);
271 float32x4_t res = vmlaq_f32(beta_vec, x_bar, gamma_vec);
272
273 // Perform fused activation
274 if(fused_activation)
275 {
276 activation_functor(res);
277 }
278
279 // Store results
280 vst1q_f32(reinterpret_cast<float *>(output.ptr()), res);
281 },
282 input, output);
283}
284
285void NEBatchNormalizationLayerKernel::configure_non_fused()
286{
287 switch(_input->info()->data_type())
288 {
289 case DataType::QS8:
290 _func = &NEBatchNormalizationLayerKernel::batch_normalization_qs8<false>;
291 break;
292 case DataType::QS16:
293 _func = &NEBatchNormalizationLayerKernel::batch_normalization_qs16<false>;
294 break;
295 case DataType::F16:
296 _func = &NEBatchNormalizationLayerKernel::batch_normalization_fp16<false>;
297 break;
298 case DataType::F32:
299 _func = &NEBatchNormalizationLayerKernel::batch_normalization_fp32<false, ::detail::dummy<float, 4>>;
300 break;
301 default:
302 ARM_COMPUTE_ERROR("Element size not supported");
303 break;
304 }
305}
306
307void NEBatchNormalizationLayerKernel::configure_fused()
308{
309 // Fused Batched Normalization with activation functions : FP32
310 static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f32 =
311 {
312 { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32<true, ::detail::relu<float, 4>> },
313 { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32<true, ::detail::brelu<float, 4>> },
314 { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32<true, ::detail::lubrelu<float, 4>> }
315 };
316
317 switch(_input->info()->data_type())
318 {
319 case DataType::F32:
320 _func = bn_fused_map_f32[_act_info.activation()];
321 break;
322 default:
323 ARM_COMPUTE_ERROR("Element size not supported");
324 break;
325 }
326}
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000327
328NEBatchNormalizationLayerKernel::NEBatchNormalizationLayerKernel()
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000329 : _func(nullptr), _input(nullptr), _output(nullptr), _mean(nullptr), _var(nullptr), _gamma(nullptr), _beta(nullptr), _epsilon(), _act_info()
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000330{
331}
Pablo Tello8fda1cb2017-07-05 15:20:38 +0100332
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000333void NEBatchNormalizationLayerKernel::configure(ITensor *input, ITensor *output,
334 const ITensor *mean, const ITensor *var,
335 const ITensor *beta, const ITensor *gamma,
336 float epsilon, ActivationLayerInfo act_info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100337{
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000338 ARM_COMPUTE_ERROR_ON_NULLPTR(input, mean, var, beta, gamma);
339
340 ITensorInfo *output_info = nullptr;
341
342 if(nullptr != output)
343 {
344 // Output tensor auto initialization if not yet initialized
345 auto_init_if_empty(*output->info(), *input->info());
346
347 output_info = output->info();
348 }
349
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000350 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output_info,
351 mean->info(), var->info(),
352 beta->info(), gamma->info(),
353 epsilon, act_info));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100354
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000355 _input = input;
356 _output = input;
357 _mean = mean;
358 _var = var;
359 _gamma = gamma;
360 _beta = beta;
361 _epsilon = epsilon;
362 _act_info = act_info;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100363
Georgios Pinitas409ee0a2017-08-18 10:16:09 +0100364 if(output != nullptr)
365 {
Georgios Pinitas409ee0a2017-08-18 10:16:09 +0100366 _output = output;
367 }
368
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000369 // Configure activation function to run
370 if(_act_info.enabled())
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100371 {
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000372 configure_fused();
373 }
374 else
375 {
376 configure_non_fused();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100377 }
378
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000379 // Configure kernel window
380 auto win_config = validate_and_configure_window(input->info(), output_info);
381 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
382 INEKernel::configure(win_config.second);
383}
384
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000385Status NEBatchNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output,
386 const ITensorInfo *mean, const ITensorInfo *var,
387 const ITensorInfo *beta, const ITensorInfo *gamma,
388 float epsilon, ActivationLayerInfo act_info)
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000389{
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000390 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mean, var, beta, gamma, epsilon, act_info));
Ioan-Cristian Szabo303be902017-11-27 16:31:10 +0000391 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output ? output->clone().get() : nullptr).first);
392
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000393 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100394}
395
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100396void NEBatchNormalizationLayerKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100397{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100398 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100399 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
400 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
401 ARM_COMPUTE_ERROR_ON(_func == nullptr);
402
Georgios Pinitas57c033b2018-02-15 12:29:44 +0000403 (this->*_func)(window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100404}