IVGCVSW-2040 Add unit tests for NHWC support in ref BatchNormalization

 * Added create workload unit tests for the NHWC data layout
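
The new cases can be run in isolation through the Boost.Test run_test
filter; assuming the test binary is built as UnitTests and the cases sit
in the CreateWorkloadRef suite (both names are assumptions here, not
taken from this change), for example:

    ./UnitTests --run_test=CreateWorkloadRef/CreateBatchNormalization*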

Change-Id: I03d66c88dc9b0340302b85012cb0152f0ec6fa72
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index 51820a4..21385d7 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -133,11 +133,12 @@
 
 template <typename BatchNormalizationFloat32Workload, armnn::DataType DataType>
 std::unique_ptr<BatchNormalizationFloat32Workload> CreateBatchNormalizationWorkloadTest(
-    armnn::IWorkloadFactory& factory, armnn::Graph& graph)
+    armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW)
 {
     // Creates the layer we're testing.
     BatchNormalizationDescriptor layerDesc;
     layerDesc.m_Eps = 0.05f;
+    layerDesc.m_DataLayout = dataLayout;
 
     BatchNormalizationLayer* const layer = graph.AddLayer<BatchNormalizationLayer>(layerDesc, "layer");
 
@@ -155,16 +156,19 @@
     Layer* const input = graph.AddLayer<InputLayer>(0, "input");
     Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
 
+    TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{ 2, 3, 1, 1 } : TensorShape{ 2, 1, 1, 3 };
+    TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{ 2, 3, 1, 1 } : TensorShape{ 2, 1, 1, 3 };
+
     // Connects up.
-    armnn::TensorInfo tensorInfo({2, 3, 1, 1}, DataType);
-    Connect(input, layer, tensorInfo);
-    Connect(layer, output, tensorInfo);
+    Connect(input, layer, TensorInfo(inputShape, DataType));
+    Connect(layer, output, TensorInfo(outputShape, DataType));
     CreateTensorHandles(graph, factory);
 
     // Makes the workload and checks it.
     auto workload = MakeAndCheckWorkload<BatchNormalizationFloat32Workload>(*layer, graph, factory);
     BatchNormalizationQueueDescriptor queueDescriptor = workload->GetData();
     BOOST_TEST(queueDescriptor.m_Parameters.m_Eps == 0.05f);
+    BOOST_TEST((queueDescriptor.m_Parameters.m_DataLayout == dataLayout));
     BOOST_TEST(queueDescriptor.m_Inputs.size() == 1);
     BOOST_TEST(queueDescriptor.m_Outputs.size() == 1);
     BOOST_TEST((queueDescriptor.m_Mean->GetTensorInfo() == TensorInfo({3}, DataType)));
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 1ec7749..d258b81 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -143,16 +143,56 @@
                                      armnn::DataType::QuantisedAsymm8>();
 }
 
-BOOST_AUTO_TEST_CASE(CreateBatchNormalizationWorkload)
+template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
+static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout)
 {
-    Graph                graph;
+    Graph graph;
     RefWorkloadFactory factory;
-    auto workload = CreateBatchNormalizationWorkloadTest<RefBatchNormalizationFloat32Workload, armnn::DataType::Float32>
-                    (factory, graph);
+    auto workload =
+            CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>(factory, graph, dataLayout);
+
+    TensorShape inputShape;
+    TensorShape outputShape;
+
+    switch (dataLayout)
+    {
+        case DataLayout::NHWC:
+            inputShape  = { 2, 1, 1, 3 };
+            outputShape = { 2, 1, 1, 3 };
+            break;
+        case DataLayout::NCHW:
+        default:
+            inputShape  = { 2, 3, 1, 1 };
+            outputShape = { 2, 3, 1, 1 };
+            break;
+    }
 
     // Checks that outputs and inputs are as we expect them (see definition of CreateBatchNormalizationWorkloadTest).
-    CheckInputOutput(
-        std::move(workload), TensorInfo({2, 3, 1, 1}, DataType::Float32), TensorInfo({2, 3, 1, 1}, DataType::Float32));
+    CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat32Workload)
+{
+    RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationFloat32Workload, armnn::DataType::Float32>
+            (DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat32WorkloadNhwc)
+{
+    RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationFloat32Workload, armnn::DataType::Float32>
+            (DataLayout::NHWC);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationUint8Workload)
+{
+    RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationUint8Workload, armnn::DataType::QuantisedAsymm8>
+            (DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationUint8WorkloadNhwc)
+{
+    RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationUint8Workload, armnn::DataType::QuantisedAsymm8>
+            (DataLayout::NHWC);
 }
 
 BOOST_AUTO_TEST_CASE(CreateConvertFp16ToFp32Float32Workload)