COMPMID-2336 Fix valgrind error for BatchNormalizationLayer on NEON with NHWC layout

Change-Id: I9ed2d0647ae3c33bce6290acfdac356ffffcb709
Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1697
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
index 683d48b..6bd30ee 100644
--- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -80,7 +80,7 @@
     return Status{};
 }
 
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output)
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, ITensorInfo *mean, ITensorInfo *var, ITensorInfo *gamma, ITensorInfo *beta)
 {
     if(output != nullptr)
     {
@@ -92,9 +92,34 @@
 
     Window                 win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
     AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
-    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
-    bool                   window_changed = update_window_and_padding(win, input_access, output_access);
-    output_access.set_valid_region(win, input->valid_region());
+    bool                   window_changed = update_window_and_padding(win, input_access);
+
+    if(output != nullptr)
+    {
+        AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+        window_changed |= update_window_and_padding(win, output_access);
+        output_access.set_valid_region(win, input->valid_region());
+    }
+
+    // In NHWC, mean, var, gamma and beta follow the channel dimension, which lies along the first axis, so they get horizontal access windows (and padding updates) just like the input
+    if(input->data_layout() == DataLayout::NHWC)
+    {
+        AccessWindowHorizontal mean_access(mean, 0, num_elems_processed_per_iteration);
+        AccessWindowHorizontal var_access(var, 0, num_elems_processed_per_iteration);
+        window_changed |= update_window_and_padding(win, mean_access, var_access);
+
+        if(gamma != nullptr)
+        {
+            AccessWindowHorizontal gamma_access(gamma, 0, num_elems_processed_per_iteration);
+            window_changed |= update_window_and_padding(win, gamma_access);
+        }
+        if(beta != nullptr)
+        {
+            AccessWindowHorizontal beta_access(beta, 0, num_elems_processed_per_iteration);
+            window_changed |= update_window_and_padding(win, beta_access);
+        }
+    }
+
     Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
     return std::make_pair(err, win);
 }
@@ -422,7 +447,8 @@
     }
 
     // Configure kernel window
-    auto win_config = validate_and_configure_window(input->info(), (run_in_place) ? nullptr : output->info());
+    auto win_config = validate_and_configure_window(input->info(), (run_in_place) ? nullptr : output->info(), mean->info(), var->info(), (gamma != nullptr) ? gamma->info() : nullptr,
+                                                    (beta != nullptr) ? beta->info() : nullptr);
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     INEKernel::configure(win_config.second);
 }
@@ -433,7 +459,9 @@
                                                  float epsilon, ActivationLayerInfo act_info)
 {
     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mean, var, beta, gamma, epsilon, act_info));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output ? output->clone().get() : nullptr).first);
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output ? output->clone().get() : nullptr, mean->clone().get(), var->clone().get(),
+                                                              (gamma != nullptr) ? gamma->clone().get() : nullptr, (beta != nullptr) ? beta->clone().get() : nullptr)
+                                .first);
 
     return Status{};
 }