diff --git a/src/backends/backendsCommon/LayerSupportRules.hpp b/src/backends/backendsCommon/LayerSupportRules.hpp
index e616ecf..a83fd62 100644
--- a/src/backends/backendsCommon/LayerSupportRules.hpp
+++ b/src/backends/backendsCommon/LayerSupportRules.hpp
@@ -186,4 +186,14 @@
     }
 };
 
+/// Rule satisfied when `info` has at least `numDimensionsToCompare` dimensions.
+struct TensorNumDimensionsAreGreaterOrEqualTo : public Rule
+{
+    TensorNumDimensionsAreGreaterOrEqualTo(const TensorInfo& info, unsigned int numDimensionsToCompare)
+    {
+        const unsigned int numDimensions = info.GetNumDimensions();
+        m_Res = numDimensionsToCompare <= numDimensions;
+    }
+};
+
 } //namespace armnn
\ No newline at end of file
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 606821b..9a4c60f 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -4143,5 +4143,197 @@
     }
 }
 
+// Checks that a BatchMatMul workload's tensors and descriptor parameters are consistent; throws on the first violation.
+void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+    const std::string descriptorName{"BatchMatMulDescriptor"};
+
+    ValidateNumInputs(workloadInfo,  descriptorName, 2);
+    ValidateNumOutputs(workloadInfo, descriptorName, 1);
+
+    // Inputs must be: both 2D+
+    // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
+    // axes N and I must be the same size
+
+    const auto& inputTensorXInfo = workloadInfo.m_InputTensorInfos[0];
+    const auto& inputTensorYInfo = workloadInfo.m_InputTensorInfos[1];
+    const auto& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+
+    std::vector<DataType> supportedTypes =
+    {
+        DataType::BFloat16,
+        DataType::Float16,
+        DataType::Float32,
+        DataType::QAsymmS8,
+        DataType::QAsymmU8,
+        DataType::QSymmS16
+    };
+
+    ValidateDataTypes(inputTensorXInfo, supportedTypes, descriptorName);
+    ValidateDataTypes(inputTensorYInfo, supportedTypes, descriptorName);
+    ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+
+    if ((inputTensorXInfo.GetNumDimensions() < 2) ||
+        (inputTensorYInfo.GetNumDimensions() < 2))
+    {
+        throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
+    }
+
+    // An explicitly assigned data layout pins the rank of its input:
+    // NCHW/NHWC require a 4D tensor and NCDHW/NDHWC require a 5D tensor.
+    auto validateLayoutRank = [&descriptorName](const auto& dataLayout,
+                                                const TensorInfo& tensorInfo,
+                                                const std::string& tensorName)
+    {
+        if (!dataLayout.has_value())
+        {
+            return;
+        }
+
+        unsigned int requiredNumDims = 0;
+        switch (dataLayout.value())
+        {
+            case DataLayout::NCHW:
+            case DataLayout::NHWC:
+                requiredNumDims = 4;
+                break;
+            case DataLayout::NCDHW:
+            case DataLayout::NDHWC:
+                requiredNumDims = 5;
+                break;
+            default:
+                return; // Any other layout places no constraint on the rank
+        }
+
+        if (tensorInfo.GetNumDimensions() != requiredNumDims)
+        {
+            throw InvalidArgumentException(descriptorName +
+                ": Input tensor " + tensorName + " does not have the correct "
+                "number of dimensions for the Data Layout that it has been assigned.");
+        }
+    };
+
+    validateLayoutRank(m_Parameters.m_DataLayoutX, inputTensorXInfo, "X");
+    validateLayoutRank(m_Parameters.m_DataLayoutY, inputTensorYInfo, "Y");
+
+    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters,
+                                                         inputTensorXInfo.GetShape(),
+                                                         inputTensorYInfo.GetShape());
+
+    if(inputTensorXInfo.GetShape()[axesToMul.first.second]
+       != inputTensorYInfo.GetShape()[axesToMul.second.first])
+    {
+        throw InvalidArgumentException(descriptorName +
+            ": The final axis of input tensor X must be the same size as "
+            "the second last axis of input tensor Y.");
+    }
+
+    auto axesNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters,
+                                                           inputTensorXInfo.GetShape(),
+                                                           inputTensorYInfo.GetShape());
+
+    {   // Separate scope so we don't pollute the rest of the scope with our temp variables
+        // Channels-first layouts (NCHW/NCDHW) cannot be combined with channels-last ones (NHWC/NDHWC).
+        // An unassigned layout is treated as NCHW: not equivalent, but only its last 2 axes matter here.
+        const DataLayout xLayout = m_Parameters.m_DataLayoutX.has_value() ?
+                                   m_Parameters.m_DataLayoutX.value() : DataLayout::NCHW;
+        const DataLayout yLayout = m_Parameters.m_DataLayoutY.has_value() ?
+                                   m_Parameters.m_DataLayoutY.value() : DataLayout::NCHW;
+
+        auto isChannelsFirst = [](DataLayout layout)
+        {
+            return layout == DataLayout::NCHW || layout == DataLayout::NCDHW;
+        };
+        auto isChannelsLast = [](DataLayout layout)
+        {
+            return layout == DataLayout::NHWC || layout == DataLayout::NDHWC;
+        };
+
+        if ((isChannelsFirst(xLayout) && isChannelsLast(yLayout)) ||
+            (isChannelsFirst(yLayout) && isChannelsLast(xLayout)))
+        {
+            throw InvalidArgumentException(descriptorName +
+                ": Invalid input tensor data layout combination.");
+        }
+    }
+
+    // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one,
+    // then check that the remaining (non-multiplied) axes of X and Y broadcast against each other.
+    const unsigned int outputTensorDimSize = std::max(inputTensorXInfo.GetNumDimensions(),
+                                                      inputTensorYInfo.GetNumDimensions());
+    if(outputTensorDimSize > 2)
+    {
+        const unsigned int numBatchDims = outputTensorDimSize - 2;
+        TensorInfo tiXNotMul = TensorInfo(TensorShape(numBatchDims), DataType::Float32);
+        TensorInfo tiYNotMul = TensorInfo(TensorShape(numBatchDims), DataType::Float32);
+        TensorInfo tiOutNotMul = TensorInfo(TensorShape(numBatchDims), DataType::Float32);
+
+        // Writes the sizes of inputInfo's axes listed in axisIndices into ti, left-padded with 1's
+        // so that both inputs end up with exactly numBatchDims dimensions.
+        auto doAxisExtension = [&](const std::vector<unsigned int>& axisIndices,
+                                   const TensorInfo& inputInfo,
+                                   TensorInfo& ti)
+        {
+            if (axisIndices.size() > numBatchDims)
+            {
+                throw InvalidArgumentException(descriptorName +
+                    ": Input tensor has more non-multiplied axes than the broadcast output.");
+            }
+            const unsigned int padding = numBatchDims - static_cast<unsigned int>(axisIndices.size());
+
+            for(unsigned int i = 0; i < numBatchDims; i++)
+            {
+                ti.GetShape()[i] = (i < padding) ? 1u : inputInfo.GetShape()[axisIndices[i - padding]];
+            }
+        };
+
+        doAxisExtension(axesNotMul.first, inputTensorXInfo, tiXNotMul);
+        doAxisExtension(axesNotMul.second, inputTensorYInfo, tiYNotMul);
+
+        for(unsigned int i = 0; i < numBatchDims; i++)
+        {
+            tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
+                                                 tiYNotMul.GetShape()[i]);
+        }
+
+        ValidateBroadcastTensorShapesMatch(tiXNotMul,
+                                           tiYNotMul,
+                                           tiOutNotMul,
+                                           descriptorName,
+                                           "input_X",
+                                           "input_Y");
+    }
+
+    // Also check descriptor parameter validity
+    // This will eventually be moved to the start of the function as explained below
+    if ((!m_Parameters.m_TransposeX.empty() && !m_Parameters.m_AdjointX.empty()) ||
+        (!m_Parameters.m_TransposeY.empty() && !m_Parameters.m_AdjointY.empty()))
+    {
+        throw InvalidArgumentException(descriptorName +
+            ": Invalid descriptor parameters - Transpose and Adjoint "
+            "vectors cannot both be true for a given input tensor.");
+    }
+
+    // A non-empty Transpose/Adjoint vector must have one entry per dimension of the input it applies to.
+    auto validateParamVectorSize = [&descriptorName](const auto& paramVector,
+                                                     const TensorInfo& tensorInfo,
+                                                     const std::string& paramName,
+                                                     const std::string& tensorName)
+    {
+        if (!paramVector.empty() && paramVector.size() != tensorInfo.GetNumDimensions())
+        {
+            throw InvalidArgumentException(descriptorName +
+                ": Invalid descriptor parameter - " + paramName + " vector must be "
+                "the same size as tensor input " + tensorName + "'s dimensionality.");
+        }
+    };
+
+    validateParamVectorSize(m_Parameters.m_TransposeX, inputTensorXInfo, "Transpose X", "X");
+    validateParamVectorSize(m_Parameters.m_AdjointX,   inputTensorXInfo, "Adjoint X",   "X");
+    validateParamVectorSize(m_Parameters.m_TransposeY, inputTensorYInfo, "Transpose Y", "Y");
+    validateParamVectorSize(m_Parameters.m_AdjointY,   inputTensorYInfo, "Adjoint Y",   "Y");
+    // Note: for adjoint/transpose, you'll need to do the validation atop the resultant permutation.
+}
+
 
 } // namespace armnn
\ No newline at end of file
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 3660e6e..70006e4 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -133,6 +133,22 @@
                     reason);
             break;
         }
+        case LayerType::BatchMatMul:
+        {
+            auto cLayer = PolymorphicDowncast<const BatchMatMulLayer*>(&layer);
+            const BatchMatMulDescriptor& descriptor = cLayer->GetParameters();
+
+            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
+            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+            result = layerSupportObject.IsBatchMatMulSupported(
+                            OverrideDataType(input0, dataType),
+                            OverrideDataType(input1, dataType),
+                            OverrideDataType(output, dataType),
+                            descriptor,
+                            reason);
+            break;
+        }
         case LayerType::BatchNormalization:
         {
             auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
diff --git a/src/backends/backendsCommon/common.mk b/src/backends/backendsCommon/common.mk
index 86de7e3..007cca5 100644
--- a/src/backends/backendsCommon/common.mk
+++ b/src/backends/backendsCommon/common.mk
@@ -46,6 +46,7 @@
     test/layerTests/ActivationTestImpl.cpp \
     test/layerTests/AdditionTestImpl.cpp \
     test/layerTests/ArgMinMaxTestImpl.cpp \
+    test/layerTests/BatchMatMulTestImpl.cpp \
     test/layerTests/BatchNormalizationTestImpl.cpp \
     test/layerTests/CastTestImpl.cpp \
     test/layerTests/ChannelShuffleTestImpl.cpp \
diff --git a/src/backends/backendsCommon/test/CMakeLists.txt b/src/backends/backendsCommon/test/CMakeLists.txt
index 8beb7c4..c5b97eb 100644
--- a/src/backends/backendsCommon/test/CMakeLists.txt
+++ b/src/backends/backendsCommon/test/CMakeLists.txt
@@ -68,6 +68,8 @@
     layerTests/AdditionTestImpl.hpp
     layerTests/ArgMinMaxTestImpl.cpp
     layerTests/ArgMinMaxTestImpl.hpp
+    layerTests/BatchMatMulTestImpl.cpp
+    layerTests/BatchMatMulTestImpl.hpp
     layerTests/BatchNormalizationTestImpl.cpp
     layerTests/BatchNormalizationTestImpl.hpp
     layerTests/BatchToSpaceNdTestImpl.hpp
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index ba8cfd5..5fdcd9c 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -614,6 +614,8 @@
 
 DECLARE_LAYER_POLICY_2_PARAM(ArgMinMax)
 
+DECLARE_LAYER_POLICY_2_PARAM(BatchMatMul)
+
 DECLARE_LAYER_POLICY_2_PARAM(BatchNormalization)
 
 DECLARE_LAYER_POLICY_2_PARAM(BatchToSpaceNd)
diff --git a/src/backends/backendsCommon/test/LayerTests.hpp b/src/backends/backendsCommon/test/LayerTests.hpp
index 8d73027..25435b2 100644
--- a/src/backends/backendsCommon/test/LayerTests.hpp
+++ b/src/backends/backendsCommon/test/LayerTests.hpp
@@ -9,6 +9,7 @@
 #include <backendsCommon/test/layerTests/ActivationTestImpl.hpp>
 #include <backendsCommon/test/layerTests/AdditionTestImpl.hpp>
 #include <backendsCommon/test/layerTests/ArgMinMaxTestImpl.hpp>
+#include <backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp>
 #include <backendsCommon/test/layerTests/BatchNormalizationTestImpl.hpp>
 #include <backendsCommon/test/layerTests/BatchToSpaceNdTestImpl.hpp>
 #include <backendsCommon/test/layerTests/CastTestImpl.hpp>
diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
new file mode 100644
index 0000000..41add6e
--- /dev/null
+++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
@@ -0,0 +1,1010 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "BatchMatMulTestImpl.hpp"
+
+#include <armnn/backends/IBackendInternal.hpp>
+#include <armnn/backends/Workload.hpp>
+#include <armnn/backends/WorkloadData.hpp>
+#include <armnn/backends/WorkloadFactory.hpp>
+
+#include <armnnTestUtils/WorkloadTestUtils.hpp>
+#include <armnnUtils/QuantizeHelper.hpp>
+#include <armnnTestUtils/TensorCopyUtils.hpp>
+#include <armnn/Optional.hpp>
+
+
+template<armnn::DataType ArmnnType, typename T, std::size_t NumDims>
+LayerTestResult<T, NumDims> BatchMatMulTestImpl(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory,
+    armnn::BatchMatMulDescriptor descriptor,
+    const std::vector<T>& inputX,
+    const std::vector<T>& inputY,
+    const std::vector<T>& outputExpected,
+    const armnn::TensorInfo& inputXInfo,
+    const armnn::TensorInfo& inputYInfo,
+    const armnn::TensorInfo& outputInfo)
+{
+    // Runs a single BatchMatMul workload on (X, Y) and pairs its output with outputExpected
+    auto xHandle   = tensorHandleFactory.CreateTensorHandle(inputXInfo);
+    auto yHandle   = tensorHandleFactory.CreateTensorHandle(inputYInfo);
+    auto outHandle = tensorHandleFactory.CreateTensorHandle(outputInfo);
+
+    armnn::WorkloadInfo workloadInfo;
+    armnn::BatchMatMulQueueDescriptor queueDescriptor;
+    queueDescriptor.m_Parameters = descriptor;
+
+    AddInputToWorkload(queueDescriptor, workloadInfo, inputXInfo, xHandle.get());
+    AddInputToWorkload(queueDescriptor, workloadInfo, inputYInfo, yHandle.get());
+    AddOutputToWorkload(queueDescriptor, workloadInfo, outputInfo, outHandle.get());
+
+    auto workload = workloadFactory.CreateWorkload(armnn::LayerType::BatchMatMul, queueDescriptor, workloadInfo);
+
+    // Allocate every handle's memory before uploading the input data
+    for (armnn::ITensorHandle* handle : {xHandle.get(), yHandle.get(), outHandle.get()})
+    {
+        handle->Allocate();
+    }
+    CopyDataToITensorHandle(xHandle.get(), inputX.data());
+    CopyDataToITensorHandle(yHandle.get(), inputY.data());
+
+    workload->PostAllocationConfigure();
+    ExecuteWorkload(*workload, memoryManager);
+
+    // Read the result back into a host buffer sized from outputInfo
+    std::vector<T> actualOutput(outputInfo.GetNumElements());
+    CopyDataFromITensorHandle(actualOutput.data(), outHandle.get());
+
+    return LayerTestResult<T, NumDims>(actualOutput, outputExpected,
+                                       outHandle->GetShape(), outputInfo.GetShape());
+}
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 2> BatchMatMul2DSimpleTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    // Arbitrary layout with no transpose/adjointing
+    const armnn::BatchMatMulDescriptor descriptor{};
+
+    // Quantized types get a unit scale and zero offset, so the small integer values below
+    // quantize to themselves; every other type keeps a scale of 0.
+    const bool isQuantizedType = (ArmnnType == armnn::DataType::QAsymmS8) ||
+                                 (ArmnnType == armnn::DataType::QAsymmU8) ||
+                                 (ArmnnType == armnn::DataType::QSymmS16);
+    const float qScale = isQuantizedType ? 1.0f : 0.0f;
+    const int32_t qOffset = 0;
+
+    // Expected result of the 2x2 product:
+    //   [[1,2],[3,4]] x [[5,6],[7,8]] = [[1*5+2*7, 1*6+2*8],
+    //                                    [3*5+4*7, 3*6+4*8]]
+    //                                 = [[19,22],[43,50]]
+
+    const armnn::TensorInfo xInfo({2,2}, ArmnnType, qScale, qOffset);
+    const armnn::TensorInfo yInfo({2,2}, ArmnnType, qScale, qOffset);
+    const armnn::TensorInfo outInfo({2,2}, ArmnnType, qScale, qOffset);
+
+    const std::vector<T> xData = armnnUtils::QuantizedVector<T>({
+        1, 2,
+        3, 4
+    }, qScale, qOffset);
+
+    const std::vector<T> yData = armnnUtils::QuantizedVector<T>({
+        5, 6,
+        7, 8
+    }, qScale, qOffset);
+
+    const std::vector<T> expectedData = armnnUtils::QuantizedVector<T>({
+        19, 22,
+        43, 50
+    }, qScale, qOffset);
+
+    return BatchMatMulTestImpl<ArmnnType, T, 2>(workloadFactory,
+                                                memoryManager,
+                                                tensorHandleFactory,
+                                                descriptor,
+                                                xData,
+                                                yData,
+                                                expectedData,
+                                                xInfo,
+                                                yInfo,
+                                                outInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 2>
+BatchMatMul2DSimpleTest<armnn::DataType::BFloat16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
+BatchMatMul2DSimpleTest<armnn::DataType::Float32>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 2>
+BatchMatMul2DSimpleTest<armnn::DataType::Float16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
+BatchMatMul2DSimpleTest<armnn::DataType::QAsymmS8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 2>
+BatchMatMul2DSimpleTest<armnn::DataType::QAsymmU8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 2>
+BatchMatMul2DSimpleTest<armnn::DataType::QSymmS16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 3> BatchMatMul3DSimpleTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    // Arbitrary layout with no transpose/adjointing
+    const armnn::BatchMatMulDescriptor descriptor{};
+
+    // Quantized types get a unit scale and zero offset, so the small integer values below
+    // quantize to themselves; every other type keeps a scale of 0.
+    const bool isQuantizedType = (ArmnnType == armnn::DataType::QAsymmS8) ||
+                                 (ArmnnType == armnn::DataType::QAsymmU8) ||
+                                 (ArmnnType == armnn::DataType::QSymmS16);
+    const float qScale = isQuantizedType ? 1.0f : 0.0f;
+    const int32_t qOffset = 0;
+
+    // Expected result of the single batch's 2x2 product:
+    //   [[1,2],[3,4]] x [[5,6],[7,8]] = [[1*5+2*7, 1*6+2*8],
+    //                                    [3*5+4*7, 3*6+4*8]]
+    //                                 = [[19,22],[43,50]]
+
+    const armnn::TensorInfo xInfo({1,2,2}, ArmnnType, qScale, qOffset);
+    const armnn::TensorInfo yInfo({1,2,2}, ArmnnType, qScale, qOffset);
+    const armnn::TensorInfo outInfo({1,2,2}, ArmnnType, qScale, qOffset);
+
+    const std::vector<T> xData = armnnUtils::QuantizedVector<T>({
+        1, 2,
+        3, 4
+    }, qScale, qOffset);
+
+    const std::vector<T> yData = armnnUtils::QuantizedVector<T>({
+        5, 6,
+        7, 8
+    }, qScale, qOffset);
+
+    const std::vector<T> expectedData = armnnUtils::QuantizedVector<T>({
+        19, 22,
+        43, 50
+    }, qScale, qOffset);
+
+    return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
+                                                memoryManager,
+                                                tensorHandleFactory,
+                                                descriptor,
+                                                xData,
+                                                yData,
+                                                expectedData,
+                                                xInfo,
+                                                yInfo,
+                                                outInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 3>
+BatchMatMul3DSimpleTest<armnn::DataType::BFloat16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 3>
+BatchMatMul3DSimpleTest<armnn::DataType::Float32>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 3>
+BatchMatMul3DSimpleTest<armnn::DataType::Float16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 3>
+BatchMatMul3DSimpleTest<armnn::DataType::QAsymmS8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 3>
+BatchMatMul3DSimpleTest<armnn::DataType::QAsymmU8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
+BatchMatMul3DSimpleTest<armnn::DataType::QSymmS16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 4> BatchMatMulNCHWSimpleTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    // Both inputs are explicitly assigned the NCHW data layout
+    const armnn::BatchMatMulDescriptor descriptor(
+        armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NCHW),
+        armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NCHW));
+
+    // Quantized types get a unit scale and zero offset, so the small integer values below
+    // quantize to themselves; every other type keeps a scale of 0.
+    const bool isQuantizedType = (ArmnnType == armnn::DataType::QAsymmS8) ||
+                                 (ArmnnType == armnn::DataType::QAsymmU8) ||
+                                 (ArmnnType == armnn::DataType::QSymmS16);
+    const float qScale = isQuantizedType ? 1.0f : 0.0f;
+    const int32_t qOffset = 0;
+
+    // {1,1,2,2} as NCHW is N=1, C=1, H=2, W=2. Expected result of the 2x2 product:
+    //   [[1,2],[3,4]] x [[5,6],[7,8]] = [[1*5+2*7, 1*6+2*8],
+    //                                    [3*5+4*7, 3*6+4*8]]
+    //                                 = [[19,22],[43,50]]
+
+    const armnn::TensorInfo xInfo({1,1,2,2}, ArmnnType, qScale, qOffset);
+    const armnn::TensorInfo yInfo({1,1,2,2}, ArmnnType, qScale, qOffset);
+    const armnn::TensorInfo outInfo({1,1,2,2}, ArmnnType, qScale, qOffset);
+
+    const std::vector<T> xData = armnnUtils::QuantizedVector<T>({
+        1, 2,
+        3, 4
+    }, qScale, qOffset);
+
+    const std::vector<T> yData = armnnUtils::QuantizedVector<T>({
+        5, 6,
+        7, 8
+    }, qScale, qOffset);
+
+    const std::vector<T> expectedData = armnnUtils::QuantizedVector<T>({
+        19, 22,
+        43, 50
+    }, qScale, qOffset);
+
+    return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
+                                                memoryManager,
+                                                tensorHandleFactory,
+                                                descriptor,
+                                                xData,
+                                                yData,
+                                                expectedData,
+                                                xInfo,
+                                                yInfo,
+                                                outInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+BatchMatMulNCHWSimpleTest<armnn::DataType::BFloat16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
+BatchMatMulNCHWSimpleTest<armnn::DataType::Float32>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 4>
+BatchMatMulNCHWSimpleTest<armnn::DataType::Float16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 4>
+BatchMatMulNCHWSimpleTest<armnn::DataType::QAsymmS8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 4>
+BatchMatMulNCHWSimpleTest<armnn::DataType::QAsymmU8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 4>
+BatchMatMulNCHWSimpleTest<armnn::DataType::QSymmS16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 4> BatchMatMulNHWCSimpleTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    // Both inputs are explicitly assigned the NHWC data layout
+    const armnn::BatchMatMulDescriptor descriptor(
+        armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC),
+        armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC));
+
+    // Quantized types get a unit scale and zero offset, so the small integer values below
+    // quantize to themselves; every other type keeps a scale of 0.
+    const bool isQuantizedType = (ArmnnType == armnn::DataType::QAsymmS8) ||
+                                 (ArmnnType == armnn::DataType::QAsymmU8) ||
+                                 (ArmnnType == armnn::DataType::QSymmS16);
+    const float qScale = isQuantizedType ? 1.0f : 0.0f;
+    const int32_t qOffset = 0;
+
+    // {1,2,2,1} as NHWC is N=1, H=2, W=2, C=1. Expected result of the 2x2 product:
+    //   [[1,2],[3,4]] x [[5,6],[7,8]] = [[1*5+2*7, 1*6+2*8],
+    //                                    [3*5+4*7, 3*6+4*8]]
+    //                                 = [[19,22],[43,50]]
+
+    const armnn::TensorInfo xInfo({1,2,2,1}, ArmnnType, qScale, qOffset);
+    const armnn::TensorInfo yInfo({1,2,2,1}, ArmnnType, qScale, qOffset);
+    const armnn::TensorInfo outInfo({1,2,2,1}, ArmnnType, qScale, qOffset);
+
+    const std::vector<T> xData = armnnUtils::QuantizedVector<T>({
+        1, 2,
+        3, 4
+    }, qScale, qOffset);
+
+    const std::vector<T> yData = armnnUtils::QuantizedVector<T>({
+        5, 6,
+        7, 8
+    }, qScale, qOffset);
+
+    const std::vector<T> expectedData = armnnUtils::QuantizedVector<T>({
+        19, 22,
+        43, 50
+    }, qScale, qOffset);
+
+    return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
+                                                memoryManager,
+                                                tensorHandleFactory,
+                                                descriptor,
+                                                xData,
+                                                yData,
+                                                expectedData,
+                                                xInfo,
+                                                yInfo,
+                                                outInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+BatchMatMulNHWCSimpleTest<armnn::DataType::BFloat16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
+BatchMatMulNHWCSimpleTest<armnn::DataType::Float32>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 4>
+BatchMatMulNHWCSimpleTest<armnn::DataType::Float16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 4>
+BatchMatMulNHWCSimpleTest<armnn::DataType::QAsymmS8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 4>
+BatchMatMulNHWCSimpleTest<armnn::DataType::QAsymmU8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 4>
+BatchMatMulNHWCSimpleTest<armnn::DataType::QSymmS16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 3> BatchMatMul3DBatchTest( // [2,2,2] x [2,2,2] -> [2,2,2]: two independent 2x2 products
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    auto descriptor = armnn::BatchMatMulDescriptor(); // Default descriptor: no explicit layouts, no transpose/adjoint
+
+    float qScale = 0.0f;  // stays 0 for float types; the quantized types switch to 1.0f below
+    int32_t qOffset = 0;
+
+    switch(ArmnnType)
+    {
+        case armnn::DataType::QAsymmS8:
+        case armnn::DataType::QAsymmU8:
+        case armnn::DataType::QSymmS16:
+            qScale = 1.0f;
+            break;
+        default:
+            break;
+    }
+
+    armnn::TensorInfo inputXInfo({2,2,2}, ArmnnType, qScale, qOffset);
+    armnn::TensorInfo inputYInfo({2,2,2}, ArmnnType, qScale, qOffset);
+    armnn::TensorInfo outputInfo({2,2,2}, ArmnnType, qScale, qOffset);
+
+    std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+       1, 2,   // X[0]
+       3, 4,
+
+       9, 10,  // X[1]
+       11, 12
+    }, qScale, qOffset);
+
+    std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+        5, 6,   // Y[0]
+        7, 8,
+
+        13, 14, // Y[1]
+        15, 16
+    }, qScale, qOffset);
+
+    std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+        19, 22,   // X[0] * Y[0]
+        43, 50,
+        // NOTE(review): 267..346 exceed int8/uint8 range; QAsymmS8/U8 with qScale 1.0 likely saturate here
+        267, 286, // X[1] * Y[1]
+        323, 346
+    },qScale, qOffset);
+
+    return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
+                                                memoryManager,
+                                                tensorHandleFactory,
+                                                descriptor,
+                                                inputX,
+                                                inputY,
+                                                outputExpected,
+                                                inputXInfo,
+                                                inputYInfo,
+                                                outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 3> // one explicit instantiation per DataType
+BatchMatMul3DBatchTest<armnn::DataType::BFloat16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 3>
+BatchMatMul3DBatchTest<armnn::DataType::Float32>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 3>
+BatchMatMul3DBatchTest<armnn::DataType::Float16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 3>
+BatchMatMul3DBatchTest<armnn::DataType::QAsymmS8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 3>
+BatchMatMul3DBatchTest<armnn::DataType::QAsymmU8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
+BatchMatMul3DBatchTest<armnn::DataType::QSymmS16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 3> BatchMatMul3DBroadcastTest( // Y's single batch is broadcast over X's 2 batches
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    auto descriptor = armnn::BatchMatMulDescriptor(); // Default descriptor: no explicit layouts, no transpose/adjoint
+
+    float qScale = 0.0f;  // stays 0 for float types; the quantized types switch to 1.0f below
+    int32_t qOffset = 0;
+
+    switch(ArmnnType)
+    {
+        case armnn::DataType::QAsymmS8:
+        case armnn::DataType::QAsymmU8:
+        case armnn::DataType::QSymmS16:
+            qScale = 1.0f;
+            break;
+        default:
+            break;
+    }
+
+    armnn::TensorInfo inputXInfo({2,2,2}, ArmnnType, qScale, qOffset);
+    armnn::TensorInfo inputYInfo({1,2,2}, ArmnnType, qScale, qOffset); // batch size 1: broadcast against X
+    armnn::TensorInfo outputInfo({2,2,2}, ArmnnType, qScale, qOffset);
+
+    std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+       1, 2,   // X[0]
+       3, 4,
+
+       9, 10,  // X[1]
+       11, 12
+    }, qScale, qOffset);
+
+    std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+        13, 14, // Y[0], the only batch
+        15, 16
+    }, qScale, qOffset);
+
+    std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+        43, 46,   // X[0] * Y[0]
+        99, 106,
+        // NOTE(review): 267..346 exceed int8/uint8 range; QAsymmS8/U8 with qScale 1.0 likely saturate here
+        267, 286, // X[1] * Y[0] (Y broadcast)
+        323, 346
+    },qScale, qOffset);
+
+    return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
+                                                memoryManager,
+                                                tensorHandleFactory,
+                                                descriptor,
+                                                inputX,
+                                                inputY,
+                                                outputExpected,
+                                                inputXInfo,
+                                                inputYInfo,
+                                                outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 3> // one explicit instantiation per DataType
+BatchMatMul3DBroadcastTest<armnn::DataType::BFloat16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 3>
+BatchMatMul3DBroadcastTest<armnn::DataType::Float32>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 3>
+BatchMatMul3DBroadcastTest<armnn::DataType::Float16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 3>
+BatchMatMul3DBroadcastTest<armnn::DataType::QAsymmS8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 3>
+BatchMatMul3DBroadcastTest<armnn::DataType::QAsymmU8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
+BatchMatMul3DBroadcastTest<armnn::DataType::QSymmS16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 3> BatchMatMul3D2DBroadcastTest( // rank-3 X times rank-2 Y; Y is reused for every X batch
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    auto descriptor = armnn::BatchMatMulDescriptor(); // Default descriptor: no explicit layouts, no transpose/adjoint
+
+    float qScale = 0.0f;  // stays 0 for float types; the quantized types switch to 1.0f below
+    int32_t qOffset = 0;
+
+    switch(ArmnnType)
+    {
+        case armnn::DataType::QAsymmS8:
+        case armnn::DataType::QAsymmU8:
+        case armnn::DataType::QSymmS16:
+            qScale = 1.0f;
+            break;
+        default:
+            break;
+    }
+
+    armnn::TensorInfo inputXInfo({2,2,2}, ArmnnType, qScale, qOffset);
+    armnn::TensorInfo inputYInfo({2,2}, ArmnnType, qScale, qOffset); // rank 2: broadcast over X's batches
+    armnn::TensorInfo outputInfo({2,2,2}, ArmnnType, qScale, qOffset);
+
+    std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+       1, 2,   // X[0]
+       3, 4,
+
+       9, 10,  // X[1]
+       11, 12
+    }, qScale, qOffset);
+
+    std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+        13, 14, // Y (single 2x2 matrix)
+        15, 16
+    }, qScale, qOffset);
+
+    std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+        43, 46,   // X[0] * Y
+        99, 106,
+        // NOTE(review): 267..346 exceed int8/uint8 range; QAsymmS8/U8 with qScale 1.0 likely saturate here
+        267, 286, // X[1] * Y
+        323, 346
+    },qScale, qOffset);
+
+    return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
+                                                memoryManager,
+                                                tensorHandleFactory,
+                                                descriptor,
+                                                inputX,
+                                                inputY,
+                                                outputExpected,
+                                                inputXInfo,
+                                                inputYInfo,
+                                                outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 3> // one explicit instantiation per DataType
+BatchMatMul3D2DBroadcastTest<armnn::DataType::BFloat16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 3>
+BatchMatMul3D2DBroadcastTest<armnn::DataType::Float32>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 3>
+BatchMatMul3D2DBroadcastTest<armnn::DataType::Float16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 3>
+BatchMatMul3D2DBroadcastTest<armnn::DataType::QAsymmS8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 3>
+BatchMatMul3D2DBroadcastTest<armnn::DataType::QAsymmU8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
+BatchMatMul3D2DBroadcastTest<armnn::DataType::QSymmS16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 5> BatchMatMulNDHWCNHWCTest( // X in NDHWC, Y in NHWC, rank-5 output
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    auto descriptor = armnn::BatchMatMulDescriptor( // data layouts of X and Y, in that order
+        armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NDHWC),
+        armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC));
+
+    float qScale = 0.0f;  // stays 0 for float types; the quantized types switch to 1.0f below
+    int32_t qOffset = 0;
+
+    switch(ArmnnType)
+    {
+        case armnn::DataType::QAsymmS8:
+        case armnn::DataType::QAsymmU8:
+        case armnn::DataType::QSymmS16:
+            qScale = 1.0f;
+            break;
+        default:
+            break;
+    }
+
+    armnn::TensorInfo inputXInfo({1,1,2,2,2}, ArmnnType, qScale, qOffset);
+    armnn::TensorInfo inputYInfo({1,2,2,2}, ArmnnType, qScale, qOffset);
+    armnn::TensorInfo outputInfo({1,1,2,2,2}, ArmnnType, qScale, qOffset);
+
+    std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+        1, 20,  // stored NDHWC (channel fastest): channel 0 H x W plane [[1,3],[2,4]],
+        3, 22,  //                                 channel 1 H x W plane [[20,22],[21,23]]
+
+        2, 21,
+        4, 23
+    }, qScale, qOffset);
+
+    std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+        5, 24,  // stored NHWC: channel 0 H x W plane [[5,7],[6,8]],
+        7, 26,  //              channel 1 H x W plane [[24,26],[25,27]]
+
+        6, 25,
+        8, 27
+    }, qScale, qOffset);
+
+    std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+       23, 1030,  // per-channel product of the H x W planes:
+       31, 1114,  // channel 0 [[23,31],[34,46]], channel 1 [[1030,1114],[1079,1167]]
+       // NOTE(review): 1030..1167 exceed int8/uint8 range; QAsymmS8/U8 with qScale 1.0 likely saturate
+       34, 1079,
+       46, 1167
+    },qScale, qOffset);
+
+    return BatchMatMulTestImpl<ArmnnType, T, 5>(workloadFactory,
+                                                memoryManager,
+                                                tensorHandleFactory,
+                                                descriptor,
+                                                inputX,
+                                                inputY,
+                                                outputExpected,
+                                                inputXInfo,
+                                                inputYInfo,
+                                                outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 5> // one explicit instantiation per DataType
+BatchMatMulNDHWCNHWCTest<armnn::DataType::BFloat16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 5>
+BatchMatMulNDHWCNHWCTest<armnn::DataType::Float32>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 5>
+BatchMatMulNDHWCNHWCTest<armnn::DataType::Float16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 5>
+BatchMatMulNDHWCNHWCTest<armnn::DataType::QAsymmS8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 5>
+BatchMatMulNDHWCNHWCTest<armnn::DataType::QAsymmU8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 5>
+BatchMatMulNDHWCNHWCTest<armnn::DataType::QSymmS16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 2> BatchMatMul2DTinyTest( // 1x1 * 1x1: smallest possible operands
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    auto descriptor = armnn::BatchMatMulDescriptor(); // Default descriptor: no explicit layouts, no transpose/adjoint
+
+    float qScale = 0.0f;  // stays 0 for float types; the quantized types switch to 1.0f below
+    int32_t qOffset = 0;
+
+    switch(ArmnnType)
+    {
+        case armnn::DataType::QAsymmS8:
+        case armnn::DataType::QAsymmU8:
+        case armnn::DataType::QSymmS16:
+            qScale = 1.0f;
+            break;
+        default:
+            break;
+    }
+
+    armnn::TensorInfo inputXInfo({1,1}, ArmnnType, qScale, qOffset);
+    armnn::TensorInfo inputYInfo({1,1}, ArmnnType, qScale, qOffset);
+    armnn::TensorInfo outputInfo({1,1}, ArmnnType, qScale, qOffset);
+
+    std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+       3
+    }, qScale, qOffset);
+
+    std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+        5
+    }, qScale, qOffset);
+
+    std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+        15  // 3 * 5
+    }, qScale, qOffset);
+
+    return BatchMatMulTestImpl<ArmnnType, T, 2>(workloadFactory,
+                                             memoryManager,
+                                             tensorHandleFactory,
+                                             descriptor,
+                                             inputX,
+                                             inputY,
+                                             outputExpected,
+                                             inputXInfo,
+                                             inputYInfo,
+                                             outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 2> // one explicit instantiation per DataType
+BatchMatMul2DTinyTest<armnn::DataType::BFloat16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
+BatchMatMul2DTinyTest<armnn::DataType::Float32>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 2>
+BatchMatMul2DTinyTest<armnn::DataType::Float16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
+BatchMatMul2DTinyTest<armnn::DataType::QAsymmS8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 2>
+BatchMatMul2DTinyTest<armnn::DataType::QAsymmU8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 2>
+BatchMatMul2DTinyTest<armnn::DataType::QSymmS16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 3> BatchMatMul3DNonSquareTest( // [2,5,3] x [2,3,4] -> [2,5,4]
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    auto descriptor = armnn::BatchMatMulDescriptor(); // Default descriptor: no explicit layouts, no transpose/adjoint
+
+    float qScale = 0.0f;  // stays 0 for float types; the quantized types switch to 1.0f below
+    int32_t qOffset = 0;
+
+    switch(ArmnnType)
+    {
+        case armnn::DataType::QAsymmS8:
+        case armnn::DataType::QAsymmU8:
+        case armnn::DataType::QSymmS16:
+            qScale = 1.0f;
+            break;
+        default:
+            break;
+    }
+
+    armnn::TensorInfo inputXInfo({2,5,3}, ArmnnType, qScale, qOffset); // 2 batches of 5x3
+    armnn::TensorInfo inputYInfo({2,3,4}, ArmnnType, qScale, qOffset); // 2 batches of 3x4
+    armnn::TensorInfo outputInfo({2,5,4}, ArmnnType, qScale, qOffset); // 2 batches of 5x4
+
+    std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+       8, 8, 4,   // X[0]
+       6, 1, 3,
+       8, 8, 3,
+       8, 9, 8,
+       5, 4, 4,
+
+       1, 8, 5,   // X[1]
+       7, 1, 1,
+       8, 7, 9,
+       3, 2, 7,
+       8, 5, 3
+    }, qScale, qOffset);
+
+    std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+        6, 2, 3, 2,  // Y[0]
+        6, 2, 2, 8,
+        3, 7, 8, 1,
+
+        7, 2, 9, 5,  // Y[1]
+        2, 3, 1, 3,
+        2, 7, 7, 5
+    }, qScale, qOffset);
+
+    std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+        108, 60, 72, 84,  // X[0] * Y[0]
+        51, 35, 44, 23,
+        105, 53, 64, 83,
+        126, 90, 106, 96,
+        66, 46, 55, 46,
+        // NOTE(review): 142 below exceeds the int8 maximum (127); QAsymmS8 with qScale 1.0 likely saturates
+        33, 61, 52, 54,   // X[1] * Y[1]
+        53, 24, 71, 43,
+        88, 100, 142, 106,
+        39, 61, 78, 56,
+        72, 52, 98, 70
+    },qScale, qOffset);
+
+    return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
+                                                memoryManager,
+                                                tensorHandleFactory,
+                                                descriptor,
+                                                inputX,
+                                                inputY,
+                                                outputExpected,
+                                                inputXInfo,
+                                                inputYInfo,
+                                                outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 3> // one explicit instantiation per DataType
+BatchMatMul3DNonSquareTest<armnn::DataType::BFloat16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 3>
+BatchMatMul3DNonSquareTest<armnn::DataType::Float32>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 3>
+BatchMatMul3DNonSquareTest<armnn::DataType::Float16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 3>
+BatchMatMul3DNonSquareTest<armnn::DataType::QAsymmS8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 3>
+BatchMatMul3DNonSquareTest<armnn::DataType::QAsymmU8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
+BatchMatMul3DNonSquareTest<armnn::DataType::QSymmS16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
new file mode 100644
index 0000000..9e21396
--- /dev/null
+++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
@@ -0,0 +1,101 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnnTestUtils/LayerTestResult.hpp>
+
+#include <ResolveType.hpp>
+
+#include <armnn/backends/IBackendInternal.hpp>
+
+#include <cstddef>
+#include <vector>
+
+// Shared driver for the BatchMatMul layer tests below: it receives the descriptor, both inputs
+// with their TensorInfos and the expected output. The definition lives in the matching .cpp.
+// NOTE(review): presumably it runs a BatchMatMul workload and pairs the actual output with
+// outputExpected - confirm against the definition.
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>, std::size_t NumDims>
+LayerTestResult<T, NumDims> BatchMatMulTestImpl(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory,
+    armnn::BatchMatMulDescriptor descriptor,
+    const std::vector<T>& inputX,
+    const std::vector<T>& inputY,
+    const std::vector<T>& outputExpected,
+    const armnn::TensorInfo& inputXInfo,
+    const armnn::TensorInfo& inputYInfo,
+    const armnn::TensorInfo& outputInfo);
+
+// BatchMatMul layer tests. The rank in each LayerTestResult is the rank of the output tensor;
+// tensor shapes, data layouts and expected values are defined in the matching .cpp.
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 2> BatchMatMul2DSimpleTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 3> BatchMatMul3DSimpleTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> BatchMatMulNCHWSimpleTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> BatchMatMulNHWCSimpleTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+// Two independent 2x2 * 2x2 products: X and Y both have batch size 2.
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 3> BatchMatMul3DBatchTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+// X has batch size 2 and Y batch size 1; Y is broadcast across X's batches.
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 3> BatchMatMul3DBroadcastTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+// Rank-3 X times rank-2 Y; Y is broadcast across X's batches.
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 3> BatchMatMul3D2DBroadcastTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+// X in NDHWC and Y in NHWC layout (selected through the descriptor); rank-5 output.
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 5> BatchMatMulNDHWCNHWCTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+// Smallest case: 1x1 * 1x1.
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 2> BatchMatMul2DTinyTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+// Non-square operands: [2,5,3] * [2,3,4] -> [2,5,4].
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 3> BatchMatMul3DNonSquareTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 8051dcf..4090901 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -79,6 +79,12 @@
                                         infos[1],
                                         *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
+        case LayerType::BatchMatMul:
+            return IsBatchMatMulSupported(infos[0],
+                                          infos[1],
+                                          infos[2],
+                                          *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
+                                          reasonIfUnsupported);
         case LayerType::BatchNormalization:
             return IsBatchNormalizationSupported(infos[0],
                                                  infos[1],
@@ -642,6 +648,52 @@
     return supported;
 }
 
+// Reference-backend support check for BatchMatMul. Every rule is evaluated, even after an
+// earlier one fails, so each violated rule hands its message to CheckSupportRule.
+bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
+                                             const TensorInfo& inputY,
+                                             const TensorInfo& output,
+                                             const BatchMatMulDescriptor& descriptor,
+                                             Optional<std::string &> reasonIfUnsupported) const
+{
+    // The descriptor is not consulted by the reference support check.
+    IgnoreUnused(descriptor);
+
+    // Element types accepted for both inputs and the output.
+    const std::array<DataType, 6> supportedTypes =
+    {
+        DataType::BFloat16, DataType::Float16,  DataType::Float32,
+        DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16
+    };
+
+    // Evaluates one rule and folds its outcome into `supported` (no short-circuiting).
+    bool supported = true;
+    auto require = [&](auto&& rule, const char* reason)
+    {
+        supported &= CheckSupportRule(rule, reasonIfUnsupported, reason);
+    };
+
+    // Type checks: each tensor on its own, then consistency between the tensors.
+    require(TypeAnyOf(inputX, supportedTypes),
+            "Reference batch matrix multiplication: input X is not a supported type");
+    require(TypeAnyOf(inputY, supportedTypes),
+            "Reference batch matrix multiplication: input Y is not a supported type");
+    require(TypeAnyOf(output, supportedTypes),
+            "Reference batch matrix multiplication: output is not a supported type");
+    require(TypesAreEqual(inputX, inputY),
+            "Reference batch matrix multiplication: input X and input Y types are mismatched");
+    require(TypesAreEqual(inputX, output),
+            "Reference batch matrix multiplication: inputs and output types are mismatched");
+
+    // Rank checks: both operands must have at least the two matrix dimensions.
+    require(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
+            "Reference batch matrix multiplication: input X is not of rank 2 or greater");
+    require(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
+            "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
+
+    return supported;
+}
+
 bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                     const TensorInfo& output,
                                                     const TensorInfo& mean,
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index aa8bd8d..b64244d 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -34,6 +34,14 @@
                               const ArgMinMaxDescriptor& descriptor,
                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
+    // Supported when X, Y and the output share the same supported data type and both inputs have
+    // rank 2 or greater.
+    bool IsBatchMatMulSupported(const TensorInfo& inputX,
+                                const TensorInfo& inputY,
+                                const TensorInfo& output,
+                                const BatchMatMulDescriptor& descriptor,
+                                Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
+
     bool IsBatchNormalizationSupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        const TensorInfo& mean,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 2d95658..093d0d5 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -170,6 +170,11 @@
             auto argMinMaxQueueDescriptor = PolymorphicDowncast<const ArgMinMaxQueueDescriptor*>(&descriptor);
             return std::make_unique<RefArgMinMaxWorkload>(*argMinMaxQueueDescriptor, info);
         }
+        case LayerType::BatchMatMul:
+        {
+            auto batchMatMulQueueDescriptor = PolymorphicDowncast<const BatchMatMulQueueDescriptor*>(&descriptor);
+            return std::make_unique<RefBatchMatMulWorkload>(*batchMatMulQueueDescriptor, info);
+        }
         case LayerType::BatchNormalization :
         {
             auto batchNormQueueDescriptor = PolymorphicDowncast<const BatchNormalizationQueueDescriptor*>(&descriptor);
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index d9a5a1d..ed942e6 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -23,6 +23,7 @@
         RefTensorHandleFactory.cpp \
         workloads/Activation.cpp \
         workloads/ArgMinMax.cpp \
+        workloads/BatchMatMulImpl.cpp \
         workloads/BatchNormImpl.cpp \
         workloads/BatchToSpaceNd.cpp \
         workloads/Broadcast.cpp \
@@ -49,6 +50,7 @@
         workloads/Reduce.cpp \
         workloads/RefActivationWorkload.cpp \
         workloads/RefArgMinMaxWorkload.cpp \
+        workloads/RefBatchMatMulWorkload.cpp \
         workloads/RefBatchNormalizationWorkload.cpp \
         workloads/RefBatchToSpaceNdWorkload.cpp \
         workloads/RefCastWorkload.cpp \
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 419ae2b..593dc78 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -1062,6 +1062,77 @@
 ARMNN_AUTO_TEST_CASE_WITH_THF(MultiplicationBroadcast1DVectorInt32, MultiplicationBroadcast1DVectorInt32Test)
 ARMNN_AUTO_TEST_CASE_WITH_THF(Multiplication5d, Multiplication5dTest)
 
+// Batch Mat Mul
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleBFloat16, BatchMatMul2DSimpleTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleFloat32, BatchMatMul2DSimpleTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleFloat16, BatchMatMul2DSimpleTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQAsymmS8, BatchMatMul2DSimpleTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQAsymmU8, BatchMatMul2DSimpleTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQSymmS16, BatchMatMul2DSimpleTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleBFloat16, BatchMatMul3DSimpleTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleFloat32, BatchMatMul3DSimpleTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleFloat16, BatchMatMul3DSimpleTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQAsymmS8, BatchMatMul3DSimpleTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQAsymmU8, BatchMatMul3DSimpleTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQSymmS16, BatchMatMul3DSimpleTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleBFloat16, BatchMatMulNCHWSimpleTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleFloat32, BatchMatMulNCHWSimpleTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleFloat16, BatchMatMulNCHWSimpleTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQAsymmS8, BatchMatMulNCHWSimpleTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQAsymmU8, BatchMatMulNCHWSimpleTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQSymmS16, BatchMatMulNCHWSimpleTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleBFloat16, BatchMatMulNHWCSimpleTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleFloat32, BatchMatMulNHWCSimpleTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleFloat16, BatchMatMulNHWCSimpleTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQAsymmS8, BatchMatMulNHWCSimpleTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQAsymmU8, BatchMatMulNHWCSimpleTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQSymmS16, BatchMatMulNHWCSimpleTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchBFloat16, BatchMatMul3DBatchTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchFloat32, BatchMatMul3DBatchTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchFloat16, BatchMatMul3DBatchTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQAsymmS8, BatchMatMul3DBatchTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQAsymmU8, BatchMatMul3DBatchTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQSymmS16, BatchMatMul3DBatchTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastBFloat16, BatchMatMul3DBroadcastTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastFloat32, BatchMatMul3DBroadcastTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastFloat16, BatchMatMul3DBroadcastTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQAsymmS8, BatchMatMul3DBroadcastTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQAsymmU8, BatchMatMul3DBroadcastTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQSymmS16, BatchMatMul3DBroadcastTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastBFloat16, BatchMatMul3D2DBroadcastTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastFloat32, BatchMatMul3D2DBroadcastTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastFloat16, BatchMatMul3D2DBroadcastTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQAsymmS8, BatchMatMul3D2DBroadcastTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQAsymmU8, BatchMatMul3D2DBroadcastTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQSymmS16, BatchMatMul3D2DBroadcastTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCBFloat16, BatchMatMulNDHWCNHWCTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCFloat32, BatchMatMulNDHWCNHWCTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCFloat16, BatchMatMulNDHWCNHWCTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQAsymmS8, BatchMatMulNDHWCNHWCTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQAsymmU8, BatchMatMulNDHWCNHWCTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQSymmS16, BatchMatMulNDHWCNHWCTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyBFloat16, BatchMatMul2DTinyTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyFloat32, BatchMatMul2DTinyTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyFloat16, BatchMatMul2DTinyTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQAsymmS8, BatchMatMul2DTinyTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQAsymmU8, BatchMatMul2DTinyTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQSymmS16, BatchMatMul2DTinyTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareBFloat16, BatchMatMul3DNonSquareTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareFloat32, BatchMatMul3DNonSquareTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareFloat16, BatchMatMul3DNonSquareTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmS8, BatchMatMul3DNonSquareTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmU8, BatchMatMul3DNonSquareTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQSymmS16, BatchMatMul3DNonSquareTest<DataType::QSymmS16>)
+
 // Batch Norm
 ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32, BatchNormFloat32Test)
 ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32Nhwc, BatchNormFloat32NhwcTest)
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.cpp b/src/backends/reference/workloads/BatchMatMulImpl.cpp
new file mode 100644
index 0000000..74a358c
--- /dev/null
+++ b/src/backends/reference/workloads/BatchMatMulImpl.cpp
@@ -0,0 +1,239 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "BatchMatMulImpl.hpp"
+
+#include <armnn/backends/WorkloadData.hpp>
+#include <armnn/Logging.hpp>
+
+namespace armnn
+{
+
+// Entry point of the reference kernel: decodes both inputs into float buffers, then visits every
+// element of the output tensor (see RecurseBMM) and writes one dot product per element.
+void BatchMatMul::BatchMatMulImpl()
+{
+    inputXData = inputXDecoder.DecodeTensor(inputXInfo.GetShape());
+    inputYData = inputYDecoder.DecodeTensor(inputYInfo.GetShape());
+    // From here on only the decoded vectors are read; the input decoders are not touched again.
+
+    // TODO: pre-transpose / pre-adjoint of the inputs (and any DataLayout changes they imply) are not
+    // implemented. Input validation and output shape inference must be updated together with them.
+
+    auto idx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
+    RecurseBMM(idx, 0);
+}
+
+// Visits every coordinate of the output tensor, fixing one dimension per recursion level:
+// curIdx[0, curDim) is already fixed. Once all dimensions are fixed (leaf), the dot product of the
+// matching X row and Y column is computed and stored at curIdx.
+// All indices are expressed in the output tensor's coordinate space (the max possible shape).
+void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim)
+{
+    if (curDim < outputInfo.GetNumDimensions())
+    {
+        for (unsigned int i = 0; i < outputInfo.GetShape()[curDim]; i++)
+        {
+            curIdx[curDim] = i;
+            RecurseBMM(curIdx, curDim + 1);
+        }
+        return;
+    }
+
+    // Leaf: curIdx is a complete output coordinate.
+    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(params,
+                                                         inputXInfo.GetShape(),
+                                                         inputYInfo.GetShape());
+
+    // Read the reduction length from Y's own shape BEFORE remapping the axes to output-rank
+    // coordinates: once X has the higher rank, the adjusted Y axis is an output-space index and no
+    // longer selects the right (or even a valid) dimension of inputYInfo's shape.
+    const unsigned int inputYRowSize = inputYInfo.GetShape()[axesToMul.second.first];
+
+    AdjustAxesToMulForUnequalRanks(axesToMul);
+
+    const unsigned int inputXColDim = axesToMul.first.second;
+    const unsigned int inputYRowDim = axesToMul.second.first;
+
+    // Start from the output coordinate; only the reduction component changes inside the loop.
+    // GetValueAt takes its index by value, so the broadcasting adjustment never alters xIdx/yIdx.
+    auto xIdx = curIdx;
+    auto yIdx = curIdx;
+
+    float sum = 0.0f;
+    for (unsigned int k = 0; k < inputYRowSize; k++)
+    {
+        xIdx[inputXColDim] = k;
+        yIdx[inputYRowDim] = k;
+        sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
+    }
+
+    SetValueAt(sum, DataSlot::Output, curIdx);
+}
+
+// GetAxesToMul reports each input's multiplication axes relative to that input's own rank. A
+// lower-rank input is aligned with the trailing dimensions of the output, so its axes are shifted
+// by the rank difference to express them in output-rank coordinates.
+void BatchMatMul::AdjustAxesToMulForUnequalRanks(
+    std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul)
+{
+    // Compare the unsigned ranks directly rather than subtracting them into a signed 'long' (a mixed
+    // signed/unsigned expression whose conversions depend on the platform's width of 'long').
+    const unsigned int xRank = inputXInfo.GetNumDimensions();
+    const unsigned int yRank = inputYInfo.GetNumDimensions();
+
+    if (xRank < yRank)
+    {
+        // Y is the larger one
+        const unsigned int rankDiff = yRank - xRank;
+        axesToMul.first.first += rankDiff;
+        axesToMul.first.second += rankDiff;
+    }
+    else if (yRank < xRank)
+    {
+        // X is the larger one
+        const unsigned int rankDiff = xRank - yRank;
+        axesToMul.second.first += rankDiff;
+        axesToMul.second.second += rankDiff;
+    }
+}
+
+// Reads one element addressed by an output-space coordinate. Input values come from the decoded
+// buffers (inputXData / inputYData); the output value is read back through outputEncoder.
+// idx is taken by value because AdjustToSafeIdx rewrites it for broadcasting.
+float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
+{
+    AdjustToSafeIdx(type, idx);
+    unsigned int flatIdx = CalcFlatIdx(type, idx);
+    float value = 0.0f;
+
+    switch(type)
+    {
+        case DataSlot::InputX:
+            value = inputXData[flatIdx];
+            break;
+        case DataSlot::InputY:
+            value = inputYData[flatIdx];
+            break;
+        case DataSlot::Output:
+            // operator[] is called for its side effect of positioning the encoder (presumably, per
+            // the BaseIterator contract) before Get() reads the element.
+            outputEncoder[flatIdx];
+            value = outputEncoder.Get();
+            break;
+        default:
+            break;
+    }
+
+    return value;
+}
+
+// Writes one element addressed by an output-space coordinate: into the decoded buffers for the
+// inputs, or through outputEncoder for the output. idx is taken by value for the same reason as in
+// GetValueAt.
+void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
+{
+    AdjustToSafeIdx(type, idx);
+
+    unsigned int flatIdx = CalcFlatIdx(type, idx);
+
+    switch(type)
+    {
+        case DataSlot::InputX:
+            inputXData[flatIdx] = value;
+            break;
+        case DataSlot::InputY:
+            inputYData[flatIdx] = value;
+            break;
+        case DataSlot::Output:
+            outputEncoder[flatIdx]; // Position the encoder, then write through it
+            outputEncoder.Set(value);
+            break;
+        default:
+            break;
+    }
+}
+
+// Maps an output-space coordinate onto an input tensor (broadcasting): leading dimensions the
+// input does not have, and dimensions where the index is outside the input's extent (size-1
+// broadcast dimensions), are set to 0. Output coordinates are left untouched.
+void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
+{
+    const TensorInfo* inputInfo = nullptr;
+    switch(type)
+    {
+        case DataSlot::InputX:
+            inputInfo = &inputXInfo;
+            break;
+        case DataSlot::InputY:
+            inputInfo = &inputYInfo;
+            break;
+        case DataSlot::Output:
+        default:
+            // Our indices are based off the output, so there is nothing to adjust
+            return;
+    }
+
+    const TensorShape& shape = inputInfo->GetShape();
+    // Number of leading output dimensions the (possibly lower-rank) input does not have
+    const unsigned int rankDiff = outputInfo.GetNumDimensions() - inputInfo->GetNumDimensions();
+
+    for (unsigned int dim = 0; dim < idx.size(); dim++)
+    {
+        // '>=' rather than '> size - 1' so a zero-sized dimension cannot underflow the bound
+        if (dim < rankDiff || idx[dim] >= shape[dim - rankDiff])
+        {
+            idx[dim] = 0; // Broadcasting
+        }
+    }
+}
+
+// Row-major flat offset of an output-space coordinate within the tensor selected by type. A
+// lower-rank input occupies the trailing entries of idx, so only its own dimensions contribute
+// strides; its leading entries are 0 after AdjustToSafeIdx, and the input shape is never indexed
+// with a negative (wrapped-around) dimension when the ranks differ by 2 or more.
+unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
+{
+    const TensorInfo* info = &outputInfo;
+    if (type == DataSlot::InputX)
+    {
+        info = &inputXInfo;
+    }
+    else if (type == DataSlot::InputY)
+    {
+        info = &inputYInfo;
+    }
+
+    const TensorShape& shape = info->GetShape();
+    const unsigned int rank = shape.GetNumDimensions();
+    // Assumes idx.size() (the output rank) >= rank, i.e. the output rank is at least each input's rank
+    const size_t offset = idx.size() - rank;
+
+    unsigned int result = 0;
+    unsigned int stride = 1;
+    for (unsigned int dim = rank; dim > 0; dim--)
+    {
+        result += idx[offset + dim - 1] * stride;
+        stride *= shape[dim - 1];
+    }
+    return result;
+}
+
+// Debug helper: formats vec as "{ a b c }". Being a template defined in this .cpp, it is only
+// usable from this translation unit.
+template <typename T>
+std::string BatchMatMul::StringifyVec(const std::vector<T>& vec)
+{
+    std::string res = "{ ";
+    for(auto x : vec)
+    {
+        res += std::to_string(x);
+        res += " ";
+    }
+    res += "}";
+    return res;
+}
+
+} // namespace armnn
\ No newline at end of file
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.hpp b/src/backends/reference/workloads/BatchMatMulImpl.hpp
new file mode 100644
index 0000000..25b6c85
--- /dev/null
+++ b/src/backends/reference/workloads/BatchMatMulImpl.hpp
@@ -0,0 +1,94 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "Encoders.hpp"
+#include "Decoders.hpp"
+
+#include <armnn/backends/WorkloadData.hpp>
+
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace armnn
+{
+
+/// Reference implementation of batch matrix multiplication: each element of the output is the dot
+/// product of a row of X and a column of Y, with missing leading dimensions and size-1 dimensions
+/// of the inputs broadcast against the output.
+///
+/// Only references are stored: the descriptor, tensor infos, decoders and encoder passed to the
+/// constructor must outlive this object.
+class BatchMatMul {
+public:
+    /// Identifies which tensor an index or value refers to.
+    enum DataSlot
+    {
+        InputX = 0,
+        InputY = 1,
+        Output = 2
+    };
+
+    BatchMatMul(const BatchMatMulDescriptor& params,
+                const TensorInfo& inputXInfo,
+                const TensorInfo& inputYInfo,
+                const TensorInfo& outputInfo,
+                Decoder<float>& inputXDecoder,
+                Decoder<float>& inputYDecoder,
+                Encoder<float>& outputEncoder)
+        : params(params),
+          inputXInfo(inputXInfo),
+          inputYInfo(inputYInfo),
+          outputInfo(outputInfo),
+          inputXDecoder(inputXDecoder),
+          inputYDecoder(inputYDecoder),
+          outputEncoder(outputEncoder)
+    {}
+
+    /// Runs the kernel: decodes X and Y, then computes and encodes every output element.
+    void BatchMatMulImpl();
+
+    /// Visits every output coordinate (curIdx), fixing dimension curDim at each recursion level,
+    /// and computes one output element per complete coordinate.
+    void RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim);
+
+    /// Adjusts the multiplication axes when the input tensors are of unequal rank, shifting the
+    /// lower-rank input's axes so they index output-rank coordinates.
+    void AdjustAxesToMulForUnequalRanks(
+        std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul);
+
+    /// Reads / writes the element of the given tensor addressed by an output-space coordinate.
+    float GetValueAt(DataSlot type, std::vector<unsigned int> idx);
+
+    void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);
+
+    /// Maps an output-space coordinate onto the given tensor, taking broadcasting into account.
+    void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);
+
+    /// Row-major flat offset of a (broadcast-adjusted) coordinate within the given tensor.
+    unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
+
+    /// Debug helper that formats a vector as "{ a b c }". It is defined in BatchMatMulImpl.cpp,
+    /// so it can only be used from that translation unit.
+    template <typename T>
+    std::string StringifyVec(const std::vector<T>& vec);
+
+private:
+    const BatchMatMulDescriptor& params;
+    const TensorInfo& inputXInfo;
+    const TensorInfo& inputYInfo;
+    const TensorInfo& outputInfo;
+    Decoder<float>& inputXDecoder;
+    Decoder<float>& inputYDecoder;
+    Encoder<float>& outputEncoder;
+
+    // Decoded (float) copies of the input tensors, filled by BatchMatMulImpl()
+    std::vector<float> inputXData;
+    std::vector<float> inputYData;
+};
+
+} // namespace armnn
\ No newline at end of file
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index b1f6d8b..b8835e3 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -10,6 +10,8 @@
     ArgMinMax.cpp
     ArgMinMax.hpp
     BaseIterator.hpp
+    BatchMatMulImpl.cpp
+    BatchMatMulImpl.hpp
     BatchNormImpl.cpp
     BatchNormImpl.hpp
     BatchToSpaceNd.cpp
@@ -69,6 +71,8 @@
     RefArgMinMaxWorkload.cpp
     RefArgMinMaxWorkload.hpp
     RefBaseWorkload.hpp
+    RefBatchMatMulWorkload.cpp
+    RefBatchMatMulWorkload.hpp
     RefBatchNormalizationWorkload.cpp
     RefBatchNormalizationWorkload.hpp
     RefBatchToSpaceNdWorkload.cpp
diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
new file mode 100644
index 0000000..388190c
--- /dev/null
+++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
@@ -0,0 +1,59 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefBatchMatMulWorkload.hpp"
+
+#include "BatchMatMulImpl.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Profiling.hpp"
+
+namespace armnn
+{
+
+RefBatchMatMulWorkload::RefBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor, const WorkloadInfo& info)
+    : RefBaseWorkload(descriptor, info)
+{}
+
+// Runs on the tensor handles bound to the queue descriptor (m_Data).
+void RefBatchMatMulWorkload::Execute() const
+{
+    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
+}
+
+// Runs on the tensor handles carried by the WorkingMemDescriptor in executionData.
+void RefBatchMatMulWorkload::ExecuteAsync(ExecutionData& executionData)
+{
+    WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
+    Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
+}
+
+// Shared implementation: inputs[0] is X, inputs[1] is Y and outputs[0] receives the result.
+void RefBatchMatMulWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
+{
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchMatMulWorkload_Execute");
+
+    const TensorInfo& inputXInfo = GetTensorInfo(inputs[0]);
+    const TensorInfo& inputYInfo = GetTensorInfo(inputs[1]);
+    const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
+
+    // Reuse the infos fetched above instead of querying each tensor handle a second time.
+    std::unique_ptr<Decoder<float>> inputXDecoder = MakeDecoder<float>(inputXInfo, inputs[0]->Map());
+    std::unique_ptr<Decoder<float>> inputYDecoder = MakeDecoder<float>(inputYInfo, inputs[1]->Map());
+    std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(outputInfo, outputs[0]->Map());
+
+    // BatchMatMul keeps references to all of these arguments; every one of them outlives bmm,
+    // which is destroyed at the end of this function.
+    BatchMatMul bmm(m_Data.m_Parameters,
+                    inputXInfo,
+                    inputYInfo,
+                    outputInfo,
+                    *inputXDecoder,
+                    *inputYDecoder,
+                    *outputEncoder);
+
+    bmm.BatchMatMulImpl();
+}
+
+} // namespace armnn
\ No newline at end of file
diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp
new file mode 100644
index 0000000..e9dfcae
--- /dev/null
+++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp
@@ -0,0 +1,34 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "RefBaseWorkload.hpp"
+#include <armnn/backends/WorkloadData.hpp>
+
+#include "BatchMatMulImpl.hpp"
+
+namespace armnn
+{
+
+/// CpuRef workload for the BatchMatMul layer.
+class RefBatchMatMulWorkload : public RefBaseWorkload<BatchMatMulQueueDescriptor>
+{
+public:
+    explicit RefBatchMatMulWorkload(const BatchMatMulQueueDescriptor& queueDescriptor,
+                                    const WorkloadInfo& workloadInfo);
+
+    /// Executes on the tensor handles held by the queue descriptor.
+    void Execute() const override;
+
+    /// Executes on the tensor handles supplied through executionData.
+    void ExecuteAsync(ExecutionData& executionData) override;
+
+private:
+    /// Common path used by both public entry points.
+    void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const;
+};
+
+} // namespace armnn
\ No newline at end of file
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index b9c7a2a..e049d8d 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -7,6 +7,7 @@
 
 #include "RefActivationWorkload.hpp"
 #include "RefArgMinMaxWorkload.hpp"
+#include "RefBatchMatMulWorkload.hpp"
 #include "RefBatchNormalizationWorkload.hpp"
 #include "RefBatchToSpaceNdWorkload.hpp"
 #include "RefCastWorkload.hpp"
