diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index f3cf544..51820a4 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -517,9 +517,16 @@
     Layer* const input = graph.AddLayer<InputLayer>(0, "input");
     Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
 
+    TensorShape inputShape = (dataLayout == DataLayout::NCHW) ?
+                TensorShape{ 3, 5, 5, 1 } : TensorShape{ 3, 1, 5, 5 };
+    TensorShape outputShape = (dataLayout == DataLayout::NCHW) ?
+                TensorShape{ 3, 5, 5, 1 } : TensorShape{ 3, 1, 5, 5 };
+
     // Connects up.
-    Connect(input, layer, TensorInfo({3, 5, 5, 1}, DataType));
-    Connect(layer, output, TensorInfo({3, 5, 5, 1}, DataType));
+    armnn::TensorInfo inputTensorInfo(inputShape, DataType);
+    armnn::TensorInfo outputTensorInfo(outputShape, DataType);
+    Connect(input, layer, inputTensorInfo);
+    Connect(layer, output, outputTensorInfo);
     CreateTensorHandles(graph, factory);
 
     // Makes the workload and checks it.
diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp
index 526dc68..756b4a6 100644
--- a/src/backends/cl/test/ClCreateWorkloadTests.cpp
+++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp
@@ -337,17 +337,20 @@
 {
     Graph graph;
     ClWorkloadFactory factory;
-
-    auto workload = CreateNormalizationWorkloadTest<NormalizationWorkloadType, DataType>
-                    (factory, graph, dataLayout);
+    auto workload = CreateNormalizationWorkloadTest<NormalizationWorkloadType, DataType>(factory, graph, dataLayout);
 
     // Checks that inputs/outputs are as we expect them (see definition of CreateNormalizationWorkloadTest).
     NormalizationQueueDescriptor queueDescriptor = workload->GetData();
     auto inputHandle  = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[0]);
     auto outputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Outputs[0]);
 
-    BOOST_TEST(CompareIClTensorHandleShape(inputHandle, {3, 5, 5, 1}));
-    BOOST_TEST(CompareIClTensorHandleShape(outputHandle, {3, 5, 5, 1}));
+    std::initializer_list<unsigned int> inputShape  = (dataLayout == DataLayout::NCHW) ?
+            std::initializer_list<unsigned int>({3, 5, 5, 1}) : std::initializer_list<unsigned int>({3, 1, 5, 5});
+    std::initializer_list<unsigned int> outputShape = (dataLayout == DataLayout::NCHW) ?
+            std::initializer_list<unsigned int>({3, 5, 5, 1}) : std::initializer_list<unsigned int>({3, 1, 5, 5});
+
+    BOOST_TEST(CompareIClTensorHandleShape(inputHandle, inputShape));
+    BOOST_TEST(CompareIClTensorHandleShape(outputHandle, outputShape));
 }
 
 BOOST_AUTO_TEST_CASE(CreateNormalizationFloat32NchwWorkload)
diff --git a/src/backends/neon/test/NeonCreateWorkloadTests.cpp b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
index 4b6ab51..a588a3e 100644
--- a/src/backends/neon/test/NeonCreateWorkloadTests.cpp
+++ b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
@@ -296,8 +296,12 @@
     NormalizationQueueDescriptor queueDescriptor = workload->GetData();
     auto inputHandle  = boost::polymorphic_downcast<INeonTensorHandle*>(queueDescriptor.m_Inputs[0]);
     auto outputHandle = boost::polymorphic_downcast<INeonTensorHandle*>(queueDescriptor.m_Outputs[0]);
-    BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo({3, 5, 5, 1}, DataType)));
-    BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo({3, 5, 5, 1}, DataType)));
+
+    TensorShape inputShape  = (dataLayout == DataLayout::NCHW) ? TensorShape{3, 5, 5, 1} : TensorShape{3, 1, 5, 5};
+    TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{3, 5, 5, 1} : TensorShape{3, 1, 5, 5};
+
+    BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo(inputShape, DataType)));
+    BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo(outputShape, DataType)));
 }
 
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 236267c..1ec7749 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -231,21 +231,40 @@
 }
 
 template <typename NormalizationWorkloadType, armnn::DataType DataType>
-static void RefCreateNormalizationWorkloadTest()
+static void RefCreateNormalizationWorkloadTest(DataLayout dataLayout)
 {
     Graph graph;
     RefWorkloadFactory factory;
-    auto workload = CreateNormalizationWorkloadTest<NormalizationWorkloadType, DataType>(factory, graph);
+    auto workload = CreateNormalizationWorkloadTest<NormalizationWorkloadType, DataType>(factory, graph, dataLayout);
+
+    TensorShape inputShape;
+    TensorShape outputShape;
+
+    switch (dataLayout)
+    {
+        case DataLayout::NHWC:
+            inputShape  = { 3, 1, 5, 5 };
+            outputShape = { 3, 1, 5, 5 };
+            break;
+        case DataLayout::NCHW:
+        default:
+            inputShape  = { 3, 5, 5, 1 };
+            outputShape = { 3, 5, 5, 1 };
+            break;
+    }
 
     // Checks that outputs and inputs are as we expect them (see definition of CreateNormalizationWorkloadTest).
-    CheckInputOutput(std::move(workload),
-                     TensorInfo({3, 5, 5, 1}, DataType),
-                     TensorInfo({3, 5, 5, 1}, DataType));
+    CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
 }
 
 BOOST_AUTO_TEST_CASE(CreateRefNormalizationNchwWorkload)
 {
-    RefCreateNormalizationWorkloadTest<RefNormalizationFloat32Workload, armnn::DataType::Float32>();
+    RefCreateNormalizationWorkloadTest<RefNormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateRefNormalizationNhwcWorkload)
+{
+    RefCreateNormalizationWorkloadTest<RefNormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC);
 }
 
 template <typename Pooling2dWorkloadType, armnn::DataType DataType>
