IVGCVSW-2023 CL and Neon implementation of BatchNorm with NHWC

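A summary of the change, derived from the diff below:

* Changed BatchNormalizationDescriptor::m_DataLayout from DataLayout to DataLayoutIndexed
* Passed the descriptor's data layout through to the ACL tensor infos in the CL and Neon workloads
* Added NHWC variants of the CL, Neon and reference CreateWorkload tests
* Added BatchNormNhwc layer tests backed by a new BatchNormTestNhwcImpl
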
Change-Id: I962e986607e5d045cd97b9eaeaea2f5ae624db35
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index a5b1d64..a7eca43 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -293,7 +293,7 @@
     {}
 
     float m_Eps;
-    DataLayout m_DataLayout;
+    DataLayoutIndexed m_DataLayout;
 };
 
 struct FakeQuantizationDescriptor
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index db1773a..01c5e9f 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -133,8 +133,20 @@
 
 template <typename BatchNormalizationFloat32Workload, armnn::DataType DataType>
 std::unique_ptr<BatchNormalizationFloat32Workload> CreateBatchNormalizationWorkloadTest(
-    armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW)
+    armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayoutIndexed dataLayout = DataLayout::NCHW)
 {
+
+    TensorShape tensorShape;
+    switch (dataLayout.GetDataLayout())
+    {
+        case DataLayout::NHWC:
+            tensorShape = { 2, 4, 4, 3 };
+            break;
+        case DataLayout::NCHW:
+        default:
+            tensorShape = { 2, 3, 4, 4 };
+    }
+
     // Creates the layer we're testing.
     BatchNormalizationDescriptor layerDesc;
     layerDesc.m_Eps = 0.05f;
@@ -156,25 +168,23 @@
     Layer* const input = graph.AddLayer<InputLayer>(0, "input");
     Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
 
-    TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{ 2, 3, 1, 1 } : TensorShape{ 2, 1, 1, 3 };
-    TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{ 2, 3, 1, 1 } : TensorShape{ 2, 1, 1, 3 };
-
-    // Connects up.
-    Connect(input, layer, TensorInfo(inputShape, DataType));
-    Connect(layer, output, TensorInfo(outputShape, DataType));
+    // Connects up.
+    armnn::TensorInfo tensorInfo(tensorShape, DataType);
+    Connect(input, layer, tensorInfo);
+    Connect(layer, output, tensorInfo);
     CreateTensorHandles(graph, factory);
 
     // Makes the workload and checks it.
     auto workload = MakeAndCheckWorkload<BatchNormalizationFloat32Workload>(*layer, graph, factory);
     BatchNormalizationQueueDescriptor queueDescriptor = workload->GetData();
     BOOST_TEST(queueDescriptor.m_Parameters.m_Eps == 0.05f);
-    BOOST_TEST((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
     BOOST_TEST(queueDescriptor.m_Inputs.size() == 1);
     BOOST_TEST(queueDescriptor.m_Outputs.size() == 1);
     BOOST_TEST((queueDescriptor.m_Mean->GetTensorInfo() == TensorInfo({3}, DataType)));
     BOOST_TEST((queueDescriptor.m_Variance->GetTensorInfo() == TensorInfo({3}, DataType)));
     BOOST_TEST((queueDescriptor.m_Gamma->GetTensorInfo() == TensorInfo({3}, DataType)));
     BOOST_TEST((queueDescriptor.m_Beta->GetTensorInfo() == TensorInfo({3}, DataType)));
+    BOOST_TEST((queueDescriptor.m_Parameters.m_DataLayout.GetDataLayout() == dataLayout.GetDataLayout()));
 
     // Returns so we can do extra, backend-specific tests.
     return workload;
diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp
index 756b4a6..b5fc031 100644
--- a/src/backends/cl/test/ClCreateWorkloadTests.cpp
+++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp
@@ -144,31 +144,53 @@
 }
 
 template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
-static void ClCreateBatchNormalizationWorkloadTest()
+static void ClCreateBatchNormalizationWorkloadTest(DataLayout dataLayout)
 {
     Graph graph;
     ClWorkloadFactory factory;
 
     auto workload = CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>
-                    (factory, graph);
+                    (factory, graph, dataLayout);
 
     // Checks that inputs/outputs are as we expect them (see definition of CreateBatchNormalizationWorkloadTest).
     BatchNormalizationQueueDescriptor queueDescriptor = workload->GetData();
     auto inputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[0]);
     auto outputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Outputs[0]);
 
-    BOOST_TEST(CompareIClTensorHandleShape(inputHandle, {2, 3, 1, 1}));
-    BOOST_TEST(CompareIClTensorHandleShape(outputHandle, {2, 3, 1, 1}));
+    switch (dataLayout)
+    {
+        case DataLayout::NHWC:
+            BOOST_TEST(CompareIClTensorHandleShape(inputHandle, { 2, 4, 4, 3 }));
+            BOOST_TEST(CompareIClTensorHandleShape(outputHandle, { 2, 4, 4, 3 }));
+            break;
+        default: // NCHW
+            BOOST_TEST(CompareIClTensorHandleShape(inputHandle, { 2, 3, 4, 4 }));
+            BOOST_TEST(CompareIClTensorHandleShape(outputHandle, { 2, 3, 4, 4 }));
+    }
 }
 
-BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloatWorkload)
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloatNchwWorkload)
 {
-    ClCreateBatchNormalizationWorkloadTest<ClBatchNormalizationFloatWorkload, armnn::DataType::Float32>();
+    ClCreateBatchNormalizationWorkloadTest<ClBatchNormalizationFloatWorkload,
+                                           armnn::DataType::Float32>(DataLayout::NCHW);
 }
 
-BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat16Workload)
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat16NchwWorkload)
 {
-    ClCreateBatchNormalizationWorkloadTest<ClBatchNormalizationFloatWorkload, armnn::DataType::Float16>();
+    ClCreateBatchNormalizationWorkloadTest<ClBatchNormalizationFloatWorkload,
+                                           armnn::DataType::Float16>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloatNhwcWorkload)
+{
+    ClCreateBatchNormalizationWorkloadTest<ClBatchNormalizationFloatWorkload,
+                                           armnn::DataType::Float32>(DataLayout::NHWC);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat16NhwcWorkload)
+{
+    ClCreateBatchNormalizationWorkloadTest<ClBatchNormalizationFloatWorkload,
+                                           armnn::DataType::Float16>(DataLayout::NHWC);
 }
 
 BOOST_AUTO_TEST_CASE(CreateConvertFp16ToFp32Workload)
diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp
index 3b1603c..a4f824a 100755
--- a/src/backends/cl/test/ClLayerTests.cpp
+++ b/src/backends/cl/test/ClLayerTests.cpp
@@ -181,6 +181,7 @@
 
 // Batch Norm
 ARMNN_AUTO_TEST_CASE(BatchNorm, BatchNormTest)
+ARMNN_AUTO_TEST_CASE(BatchNormNhwc, BatchNormNhwcTest)
 
 // L2 Normalization
 ARMNN_AUTO_TEST_CASE(L2Normalization1d, L2Normalization1dTest)
diff --git a/src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.cpp b/src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.cpp
index 5bff7a6..24be7cd 100644
--- a/src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.cpp
@@ -23,12 +23,20 @@
                                                  const TensorInfo& gamma,
                                                  const BatchNormalizationDescriptor &desc)
 {
-    const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
-    const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
-    const arm_compute::TensorInfo aclMeanInfo = BuildArmComputeTensorInfo(mean);
-    const arm_compute::TensorInfo aclVarInfo = BuildArmComputeTensorInfo(var);
-    const arm_compute::TensorInfo aclBetaInfo = BuildArmComputeTensorInfo(beta);
-    const arm_compute::TensorInfo aclGammaInfo = BuildArmComputeTensorInfo(gamma);
+    const DataLayout dataLayout = desc.m_DataLayout.GetDataLayout();
+
+    const arm_compute::TensorInfo aclInputInfo =
+          armcomputetensorutils::BuildArmComputeTensorInfo(input, dataLayout);
+    const arm_compute::TensorInfo aclOutputInfo =
+          armcomputetensorutils::BuildArmComputeTensorInfo(output, dataLayout);
+    const arm_compute::TensorInfo aclMeanInfo =
+          armcomputetensorutils::BuildArmComputeTensorInfo(mean, dataLayout);
+    const arm_compute::TensorInfo aclVarInfo =
+          armcomputetensorutils::BuildArmComputeTensorInfo(var, dataLayout);
+    const arm_compute::TensorInfo aclBetaInfo =
+          armcomputetensorutils::BuildArmComputeTensorInfo(beta, dataLayout);
+    const arm_compute::TensorInfo aclGammaInfo =
+          armcomputetensorutils::BuildArmComputeTensorInfo(gamma, dataLayout);
 
     return arm_compute::CLBatchNormalizationLayer::validate(&aclInputInfo,
                                                             &aclOutputInfo,
@@ -60,6 +68,11 @@
     arm_compute::ICLTensor& input  = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
     arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
 
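+    // Set the ACL tensor layout from the descriptor before configuring the layer.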
+    arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout.GetDataLayout());
+    input.info()->set_data_layout(aclDataLayout);
+    output.info()->set_data_layout(aclDataLayout);
+
     m_Layer.configure(&input,
                       &output,
                       m_Mean.get(),
diff --git a/src/backends/neon/test/NeonCreateWorkloadTests.cpp b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
index a588a3e..8d5574c 100644
--- a/src/backends/neon/test/NeonCreateWorkloadTests.cpp
+++ b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
@@ -153,30 +153,45 @@
 }
 
 template <typename BatchNormalizationWorkloadType, typename armnn::DataType DataType>
-static void NeonCreateBatchNormalizationWorkloadTest()
+static void NeonCreateBatchNormalizationWorkloadTest(DataLayout dataLayout)
 {
     Graph                graph;
     NeonWorkloadFactory  factory;
-    auto workload = CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>(factory, graph);
+    auto workload = CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>
+                    (factory, graph, dataLayout);
 
     // Checks that outputs and inputs are as we expect them (see definition of CreateBatchNormalizationWorkloadTest).
     BatchNormalizationQueueDescriptor queueDescriptor = workload->GetData();
     auto inputHandle  = boost::polymorphic_downcast<INeonTensorHandle*>(queueDescriptor.m_Inputs[0]);
     auto outputHandle = boost::polymorphic_downcast<INeonTensorHandle*>(queueDescriptor.m_Outputs[0]);
-    BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo({2, 3, 1, 1}, DataType)));
-    BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo({2, 3, 1, 1}, DataType)));
+
+    TensorShape inputShape  = (dataLayout == DataLayout::NCHW) ? TensorShape{2, 3, 4, 4} : TensorShape{2, 4, 4, 3};
+    TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{2, 3, 4, 4} : TensorShape{2, 4, 4, 3};
+
+    BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo(inputShape, DataType)));
+    BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo(outputShape, DataType)));
 }
 
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat16Workload)
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat16NchwWorkload)
 {
-    NeonCreateBatchNormalizationWorkloadTest<NeonBatchNormalizationFloatWorkload, DataType::Float16>();
+    NeonCreateBatchNormalizationWorkloadTest<NeonBatchNormalizationFloatWorkload, DataType::Float16>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat16NhwcWorkload)
+{
+    NeonCreateBatchNormalizationWorkloadTest<NeonBatchNormalizationFloatWorkload, DataType::Float16>(DataLayout::NHWC);
 }
 #endif
 
-BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloatWorkload)
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloatNchwWorkload)
 {
-    NeonCreateBatchNormalizationWorkloadTest<NeonBatchNormalizationFloatWorkload, DataType::Float32>();
+    NeonCreateBatchNormalizationWorkloadTest<NeonBatchNormalizationFloatWorkload, DataType::Float32>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloatNhwcWorkload)
+{
+    NeonCreateBatchNormalizationWorkloadTest<NeonBatchNormalizationFloatWorkload, DataType::Float32>(DataLayout::NHWC);
 }
 
 template <typename armnn::DataType DataType>
diff --git a/src/backends/neon/test/NeonLayerTests.cpp b/src/backends/neon/test/NeonLayerTests.cpp
index 31ee7d8..568a236 100644
--- a/src/backends/neon/test/NeonLayerTests.cpp
+++ b/src/backends/neon/test/NeonLayerTests.cpp
@@ -338,6 +338,7 @@
 
 // Batch Norm
 ARMNN_AUTO_TEST_CASE(BatchNorm, BatchNormTest)
+ARMNN_AUTO_TEST_CASE(BatchNormNhwc, BatchNormNhwcTest)
 
 // Constant
 ARMNN_AUTO_TEST_CASE(Constant, ConstantTest)
diff --git a/src/backends/neon/workloads/NeonBatchNormalizationFloatWorkload.cpp b/src/backends/neon/workloads/NeonBatchNormalizationFloatWorkload.cpp
index f7056a5..95cfdce 100644
--- a/src/backends/neon/workloads/NeonBatchNormalizationFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonBatchNormalizationFloatWorkload.cpp
@@ -21,12 +21,20 @@
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor)
 {
-    const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
-    const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
-    const arm_compute::TensorInfo aclMeanInfo = BuildArmComputeTensorInfo(mean);
-    const arm_compute::TensorInfo aclVarInfo = BuildArmComputeTensorInfo(var);
-    const arm_compute::TensorInfo aclBetaInfo = BuildArmComputeTensorInfo(beta);
-    const arm_compute::TensorInfo aclGammaInfo = BuildArmComputeTensorInfo(gamma);
+    const DataLayout dataLayout = descriptor.m_DataLayout.GetDataLayout();
+
+    const arm_compute::TensorInfo aclInputInfo =
+          armcomputetensorutils::BuildArmComputeTensorInfo(input, dataLayout);
+    const arm_compute::TensorInfo aclOutputInfo =
+          armcomputetensorutils::BuildArmComputeTensorInfo(output, dataLayout);
+    const arm_compute::TensorInfo aclMeanInfo =
+          armcomputetensorutils::BuildArmComputeTensorInfo(mean, dataLayout);
+    const arm_compute::TensorInfo aclVarInfo =
+          armcomputetensorutils::BuildArmComputeTensorInfo(var, dataLayout);
+    const arm_compute::TensorInfo aclBetaInfo =
+          armcomputetensorutils::BuildArmComputeTensorInfo(beta, dataLayout);
+    const arm_compute::TensorInfo aclGammaInfo =
+          armcomputetensorutils::BuildArmComputeTensorInfo(gamma, dataLayout);
 
     return arm_compute::NEBatchNormalizationLayer::validate(&aclInputInfo,
                                                             &aclOutputInfo,
@@ -46,6 +54,11 @@
     arm_compute::ITensor& input = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
     arm_compute::ITensor& output = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
 
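+    // Set the ACL tensor layout from the descriptor before configuring the layer.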
+    arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout.GetDataLayout());
+    input.info()->set_data_layout(aclDataLayout);
+    output.info()->set_data_layout(aclDataLayout);
+
     m_Mean = std::make_unique<arm_compute::Tensor>();
     BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo());
 
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index d258b81..8bad549 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -157,13 +157,13 @@
     switch (dataLayout)
     {
         case DataLayout::NHWC:
-            inputShape  = { 2, 1, 1, 3 };
-            outputShape = { 2, 1, 1, 3 };
+            inputShape  = { 2, 4, 4, 3 };
+            outputShape = { 2, 4, 4, 3 };
             break;
         case DataLayout::NCHW:
         default:
-            inputShape  = { 2, 3, 1, 1 };
-            outputShape = { 2, 3, 1, 1 };
+            inputShape  = { 2, 3, 4, 4 };
+            outputShape = { 2, 3, 4, 4 };
             break;
     }
 
diff --git a/src/backends/test/BatchNormTestImpl.hpp b/src/backends/test/BatchNormTestImpl.hpp
index 4941b00..166c444 100644
--- a/src/backends/test/BatchNormTestImpl.hpp
+++ b/src/backends/test/BatchNormTestImpl.hpp
@@ -95,3 +95,96 @@
 
     return result;
 }
+
+
+template<typename T>
+LayerTestResult<T,4> BatchNormTestNhwcImpl(armnn::IWorkloadFactory& workloadFactory,
+                                           float qScale,
+                                           int32_t qOffset)
+{
+    const unsigned int width    = 2;
+    const unsigned int height   = 3;
+    const unsigned int channels = 2;
+    const unsigned int num      = 1;
+
+    armnn::TensorInfo inputTensorInfo({num, height, width, channels}, armnn::GetDataType<T>());
+    armnn::TensorInfo outputTensorInfo({num, height, width, channels}, armnn::GetDataType<T>());
+    armnn::TensorInfo tensorInfo({channels}, armnn::GetDataType<T>());
+
+    // Set quantization parameters if the requested type is a quantized type.
+    if(armnn::IsQuantizedType<T>())
+    {
+        inputTensorInfo.SetQuantizationScale(qScale);
+        inputTensorInfo.SetQuantizationOffset(qOffset);
+        outputTensorInfo.SetQuantizationScale(qScale);
+        outputTensorInfo.SetQuantizationOffset(qOffset);
+        tensorInfo.SetQuantizationScale(qScale);
+        tensorInfo.SetQuantizationOffset(qOffset);
+    }
+
+    auto input = MakeTensor<T, 4>(inputTensorInfo,
+        QuantizedVector<T>(qScale, qOffset,
+        {
+            1.f, 1.f, 4.f, 1.f,
+            4.f, 4.f, 2.f, 1.f,
+            1.f, -2.f, 6.f, 4.f
+        }));
+    // These values are per-channel of the input.
+    auto mean     = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {3, -2}));
+    auto variance = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {4, 9}));
+    auto beta     = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {3, 2}));
+    auto gamma    = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {2, 1}));
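+    // With eps of zero these reduce to y = x for channel 0 and y = (x + 2) / 3 + 2 for channel 1.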
+    LayerTestResult<T,4> ret(outputTensorInfo);
+
+    std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
+    std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
+
+    armnn::BatchNormalizationQueueDescriptor data;
+    armnn::WorkloadInfo info;
+    armnn::ScopedCpuTensorHandle meanTensor(tensorInfo);
+    armnn::ScopedCpuTensorHandle varianceTensor(tensorInfo);
+    armnn::ScopedCpuTensorHandle betaTensor(tensorInfo);
+    armnn::ScopedCpuTensorHandle gammaTensor(tensorInfo);
+
+    AllocateAndCopyDataToITensorHandle(&meanTensor, &mean[0]);
+    AllocateAndCopyDataToITensorHandle(&varianceTensor, &variance[0]);
+    AllocateAndCopyDataToITensorHandle(&betaTensor, &beta[0]);
+    AllocateAndCopyDataToITensorHandle(&gammaTensor, &gamma[0]);
+
+    AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
+    AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
+    data.m_Mean             = &meanTensor;
+    data.m_Variance         = &varianceTensor;
+    data.m_Beta             = &betaTensor;
+    data.m_Gamma            = &gammaTensor;
+    data.m_Parameters.m_Eps = 0.0f;
+    data.m_Parameters.m_DataLayout = armnn::DataLayout::NHWC;
+
+    // For each channel:
+    // subtract mean, divide by standard deviation (with an epsilon to avoid div by 0),
+    // multiply by gamma and add beta:
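+    // y = gamma * (x - mean) / sqrt(variance + eps) + beta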
+    ret.outputExpected = MakeTensor<T, 4>(outputTensorInfo,
+        QuantizedVector<T>(qScale, qOffset,
+        {
+            1.f, 3.f, 4.f, 3.f,
+            4.f, 4.f, 2.f, 3.f,
+            1.f, 2.f, 6.f, 4.f
+        }));
+
+    std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateBatchNormalization(data, info);
+
+    inputHandle->Allocate();
+    outputHandle->Allocate();
+
+    CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
+
+    workloadFactory.Finalize();
+    workload->Execute();
+
+    CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
+
+    return ret;
+}
\ No newline at end of file
diff --git a/src/backends/test/LayerTests.hpp b/src/backends/test/LayerTests.hpp
index b6651ce..925e3e6 100644
--- a/src/backends/test/LayerTests.hpp
+++ b/src/backends/test/LayerTests.hpp
@@ -256,6 +256,7 @@
 LayerTestResult<float, 4> ResizeBilinearMagNhwcTest(armnn::IWorkloadFactory& workloadFactory);
 
 LayerTestResult<float, 4> BatchNormTest(armnn::IWorkloadFactory& workloadFactory);
+LayerTestResult<float, 4> BatchNormNhwcTest(armnn::IWorkloadFactory& workloadFactory);
 
 LayerTestResult<float, 2> FakeQuantizationTest(armnn::IWorkloadFactory& workloadFactory);