diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp
index 4ba2a9e..10e2a54 100644
--- a/src/backends/cl/test/ClLayerTests.cpp
+++ b/src/backends/cl/test/ClLayerTests.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -80,6 +80,9 @@
 ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3DSimpleFloat32,
                                  ClContextControlFixture,
                                  BatchMatMul3DSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMulNCHWSimpleFloat32, // NCHW layout; presumably covers 3D-reduction path
+                                 ClContextControlFixture,
+                                 BatchMatMulNCHWSimpleTest<DataType::Float32>);
 ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3DBatchFloat32,
                                  ClContextControlFixture,
                                  BatchMatMul3DBatchTest<DataType::Float32>);
diff --git a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
index ece87c2..f21666b 100644
--- a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
+++ b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
@@ -13,6 +13,7 @@
 #include <armnn/utility/PolymorphicDowncast.hpp>
 
 #include <armnnUtils/Permute.hpp>
+#include <armnnUtils/TensorUtils.hpp>
 
 #include <backendsCommon/WorkloadUtils.hpp>
 
@@ -24,6 +25,7 @@
 
 namespace armnn
 {
+
 arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
                                           const TensorInfo& inputY,
                                           const TensorInfo& output,
@@ -42,36 +44,41 @@
     arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK);
     arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK);
 
-    const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX);
-    const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY);
-    const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+    // ClGemmMatrixMultiplyNativeKernel used by CLGEMM can only support 3 dimensional
+    // tensors so try to reduce the dimensions to 3
+    const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX, 3);
+    const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY, 3);
+    const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayoutY, 3);
 
     arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
     arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
 
     if (descriptor.m_TransposeX == true)
     {
-        auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputX.GetNumDimensions());
+        armnn::TensorInfo inputXStripped = armnnUtils::ReduceDims(inputX, 3);
+
+        auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputXStripped.GetNumDimensions());
         const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
-        const TensorInfo permutedXInfo = armnnUtils::Permuted(inputX, permutationXVector);
-        aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo);
+        const TensorInfo permutedXInfo = armnnUtils::Permuted(inputXStripped, permutationXVector);
+        aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo, 3);
 
         statusPermuteX =  arm_compute::CLPermute::validate(&aclInputXInfo,
                                                            &aclPermutedXInfo,
                                                            aclPermutationXVector);
     }
 
-    if ( descriptor.m_TransposeY == true)
+    if (descriptor.m_TransposeY == true)
     {
-        auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputY.GetNumDimensions());
+        armnn::TensorInfo inputYStripped = armnnUtils::ReduceDims(inputY, 3);
+
+        auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputYStripped.GetNumDimensions());
         const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
-        const TensorInfo permutedYInfo = armnnUtils::Permuted(inputY, permutationYVector);
-        aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo);
+        const TensorInfo permutedYInfo = armnnUtils::Permuted(inputYStripped, permutationYVector);
+        aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo, 3);
 
         statusPermuteY =  arm_compute::CLPermute::validate(&aclInputYInfo,
                                                            &aclPermutedYInfo,
                                                            aclPermutationYVector);
-
     }
 
     const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false,  // is inputX reshaped
@@ -133,16 +140,24 @@
     arm_compute::ICLTensor& output = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
 
     inputX.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX));
+    arm_compute::TensorShape inputXTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
+            info.m_InputTensorInfos[0].GetShape(), 3);
+    inputX.info()->set_tensor_shape(inputXTensorInfo);
     inputY.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY));
+    arm_compute::TensorShape inputYTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
+            info.m_InputTensorInfos[1].GetShape(), 3);
+    inputY.info()->set_tensor_shape(inputYTensorInfo);
 
     arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
     arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
 
     if (descriptor.m_Parameters.m_TransposeX == true)
     {
+        armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[0], 3);
+
         armnn::PermutationVector permutationXVector
-                = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[0].GetNumDimensions());
-        const TensorInfo permutedXInfo = armnnUtils::Permuted(info.m_InputTensorInfos[0], permutationXVector);
+                = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions());
+        const TensorInfo permutedXInfo = armnnUtils::Permuted(strippedInfo, permutationXVector);
         const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
         armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
         armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorX);
@@ -157,9 +172,11 @@
 
     if (descriptor.m_Parameters.m_TransposeY == true)
     {
+        armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[1], 3);
+
         armnn::PermutationVector permutationYVector
-                = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[1].GetNumDimensions());
-        const TensorInfo permutedYInfo = armnnUtils::Permuted(info.m_InputTensorInfos[1], permutationYVector);
+                = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions());
+        const TensorInfo permutedYInfo = armnnUtils::Permuted(strippedInfo, permutationYVector);
         const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
         armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
         armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorY);
