IVGCVSW-7105: BatchMatMul Optional Parameter Support

  * Added transpose parameters to pre-transpose each input tensor's slices
  * Added adjoint parameters to pre-adjoint each input tensor's slices
  * Small refactoring (BatchMatMulDescriptor static helpers and BatchMatMulImpl constructor)
  * Updated input validation and output shape inference for parameters
  * Additional layer unit tests for parameters added
  * Versionings incremented

Signed-off-by: Samuel Yap <samuel.yap@arm.com>
Change-Id: Ibe5242a8a5bf604c13de0dc65844fd6c421cc667
diff --git a/InstallationViaAptRepository.md b/InstallationViaAptRepository.md
index 3fa36f6..037e5cc 100644
--- a/InstallationViaAptRepository.md
+++ b/InstallationViaAptRepository.md
@@ -117,7 +117,7 @@
  sudo apt-get install -y python3-pyarmnn armnn-latest-all
  # Verify installation via python:
  python3 -c "import pyarmnn as ann;print(ann.GetVersion())"
- # Returns '{ARMNN_MAJOR_VERSION}.0.0' e.g. 30.0.0
+ # Returns '{ARMNN_MAJOR_VERSION}.0.0' e.g. 31.0.0
 ```
 This will install PyArmNN and the three backends for Neon (CpuAcc), OpenCL (GpuAcc) and our Reference Backend.
 It will also install their dependencies including the arm-compute-library package along with the Tensorflow Lite Parser
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 38e3c61..493ce65 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 #pragma once
@@ -1553,55 +1553,74 @@
 /// A BatchMatMulDescriptor for the BatchMatMul operator
 struct BatchMatMulDescriptor : BaseDescriptor
 {
-    BatchMatMulDescriptor(Optional<DataLayout> dataLayoutX = EmptyOptional(),
-                          Optional<DataLayout> dataLayoutY = EmptyOptional(),
-                          std::vector<unsigned int> transposeX = {},
-                          std::vector<unsigned int> transposeY = {},
-                          std::vector<unsigned int> adjointX = {},
-                          std::vector<unsigned int> adjointY = {})
-        : m_DataLayoutX(dataLayoutX)
-        , m_DataLayoutY(dataLayoutY)
-        , m_TransposeX(transposeX)
+    BatchMatMulDescriptor(bool transposeX = false,
+                          bool transposeY = false,
+                          bool adjointX = false,
+                          bool adjointY = false,
+                          DataLayout dataLayoutX = DataLayout::NCHW,
+                          DataLayout dataLayoutY = DataLayout::NCHW)
+        : m_TransposeX(transposeX)
         , m_TransposeY(transposeY)
         , m_AdjointX(adjointX)
         , m_AdjointY(adjointY)
+        , m_DataLayoutX(dataLayoutX)
+        , m_DataLayoutY(dataLayoutY)
     {}
 
     bool operator ==(const BatchMatMulDescriptor &rhs)  const
     {
-        return m_DataLayoutX == rhs.m_DataLayoutX &&
-               m_DataLayoutY == rhs.m_DataLayoutY &&
-               m_TransposeX == rhs.m_TransposeX &&
+        return m_TransposeX == rhs.m_TransposeX &&
                m_TransposeY == rhs.m_TransposeY &&
                m_AdjointX == rhs.m_AdjointX &&
-               m_AdjointY == rhs.m_AdjointY;
+               m_AdjointY == rhs.m_AdjointY &&
+               m_DataLayoutX == rhs.m_DataLayoutX &&
+               m_DataLayoutY == rhs.m_DataLayoutY;
     }
 
-    /// Data layout of each input tensor, such as NHWC/NDHWC (or leave as EmptyOptional for arbitrary layout)
-    Optional<DataLayout> m_DataLayoutX;
-    Optional<DataLayout> m_DataLayoutY;
-
-    /// Transpose vector for each input tensor (leave as empty vector for no pre-transposing)
+    /// Transpose the slices of each input tensor
     /// Transpose and Adjoint can not both be set to true for the same tensor at the same time
-    std::vector<unsigned int> m_TransposeX;
-    std::vector<unsigned int> m_TransposeY;
+    bool m_TransposeX;
+    bool m_TransposeY;
 
-    /// Adjoint vector for each input tensor (leave as empty vector for no pre-adjoint)
+    /// Adjoint the slices of each input tensor
     /// Transpose and Adjoint can not both be set to true for the same tensor at the same time
-    std::vector<unsigned int> m_AdjointX;
-    std::vector<unsigned int> m_AdjointY;
+    bool m_AdjointX;
+    bool m_AdjointY;
 
-    /// Static helper to get the two axes (for each input) for multiplication
+    /// Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout)
+    DataLayout m_DataLayoutX;
+    DataLayout m_DataLayoutY;
+
+    ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This method is deprecated. Use ABI Stable "
+                                      "GetAxesToMul(DataLayout dataLayout, const TensorShape& tensorShape) instead.",
+                                      "23.05")
     static std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>> GetAxesToMul(
         const BatchMatMulDescriptor& desc,
         const TensorShape& tensorXShape,
         const TensorShape& tensorYShape);
 
-    /// Static helper to get the axes (for each input) that will not be multiplied together
+    ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This method is deprecated. Use ABI Stable "
+                                      "GetAxesNotMul(DataLayout dataLayout, const TensorShape& tensorShape) instead.",
+                                      "23.05")
     static std::pair<std::vector<unsigned int>, std::vector<unsigned int>> GetAxesNotMul(
         const BatchMatMulDescriptor& desc,
         const TensorShape& inputXShape,
         const TensorShape& inputYShape);
+
+    /// Static helper to get the two axes (for each input) for multiplication
+    static std::pair<unsigned int, unsigned int> GetAxesToMul(
+        DataLayout dataLayout,
+        const TensorShape& tensorShape);
+
+    /// Static helper to get the axes (for each input) that will not be multiplied together
+    static std::vector<unsigned int> GetAxesNotMul(
+        DataLayout dataLayout,
+        const TensorShape& tensorShape);
+
+    /// Static helper to get the axes which will be transposed
+    static PermutationVector GetPermuteVec(
+        DataLayout dataLayout,
+        const TensorShape& tensorShape);
 };
 
 } // namespace armnn
diff --git a/include/armnn/Version.hpp b/include/armnn/Version.hpp
index 7951eac..7fdb20a 100644
--- a/include/armnn/Version.hpp
+++ b/include/armnn/Version.hpp
@@ -10,7 +10,7 @@
 #define STRINGIFY_MACRO(s) #s
 
 // ArmNN version components
-#define ARMNN_MAJOR_VERSION 30
+#define ARMNN_MAJOR_VERSION 31
 #define ARMNN_MINOR_VERSION 0
 #define ARMNN_PATCH_VERSION 0
 
diff --git a/include/armnnOnnxParser/Version.hpp b/include/armnnOnnxParser/Version.hpp
index 33a2846..5fbaf0c 100644
--- a/include/armnnOnnxParser/Version.hpp
+++ b/include/armnnOnnxParser/Version.hpp
@@ -14,7 +14,7 @@
 
 // OnnxParser version components
 #define ONNX_PARSER_MAJOR_VERSION 24
-#define ONNX_PARSER_MINOR_VERSION 5
+#define ONNX_PARSER_MINOR_VERSION 6
 #define ONNX_PARSER_PATCH_VERSION 0
 
 /// ONNX_PARSER_VERSION: "X.Y.Z"
diff --git a/include/armnnTfLiteParser/Version.hpp b/include/armnnTfLiteParser/Version.hpp
index 5db527e..43fa436 100644
--- a/include/armnnTfLiteParser/Version.hpp
+++ b/include/armnnTfLiteParser/Version.hpp
@@ -14,7 +14,7 @@
 
 // TfLiteParser version components
 #define TFLITE_PARSER_MAJOR_VERSION 24
-#define TFLITE_PARSER_MINOR_VERSION 5
+#define TFLITE_PARSER_MINOR_VERSION 6
 #define TFLITE_PARSER_PATCH_VERSION 0
 
 /// TFLITE_PARSER_VERSION: "X.Y.Z"
diff --git a/python/pyarmnn/README.md b/python/pyarmnn/README.md
index 547a868..5e8ceb4 100644
--- a/python/pyarmnn/README.md
+++ b/python/pyarmnn/README.md
@@ -91,14 +91,14 @@
 ```bash
 $ python setup.py sdist
 ```
-As the result you will get `./dist/pyarmnn-30.0.0.tar.gz` file. As you can see it is platform independent.
+As the result you will get `./dist/pyarmnn-31.0.0.tar.gz` file. As you can see it is platform independent.
 
 ##### 5. Build the binary package
 
 ```bash
 $ python setup.py bdist_wheel
 ```
-As the result you will get something like `./dist/pyarmnn-30.0.0-cp36-cp36m-linux_x86_64.whl` file. As you can see it
+As the result you will get something like `./dist/pyarmnn-31.0.0-cp36-cp36m-linux_x86_64.whl` file. As you can see it
  is platform dependent.
 
 # PyArmNN installation
@@ -107,8 +107,8 @@
 
 Binary package is platform dependent, the name of the package will indicate the platform it was built for, e.g.:
 
-* Linux x86 64bit machine: pyarmnn-30.0.0-cp36-cp36m-*linux_x86_64*.whl
-* Linux Aarch 64 bit machine: pyarmnn-30.0.0-cp36-cp36m-*linux_aarch64*.whl
+* Linux x86 64bit machine: pyarmnn-31.0.0-cp36-cp36m-*linux_x86_64*.whl
+* Linux Aarch 64 bit machine: pyarmnn-31.0.0-cp36-cp36m-*linux_aarch64*.whl
 
 The source package is platform independent but installation involves compilation of Arm NN python extension. You will need to have g++ compatible with C++ 14 standard and a python development library installed on the build machine.
 
@@ -126,7 +126,7 @@
 ```
 Install PyArmNN from binary by pointing to the wheel file:
 ```bash
-$ pip install /path/to/pyarmnn-30.0.0-cp36-cp36m-linux_aarch64.whl
+$ pip install /path/to/pyarmnn-31.0.0-cp36-cp36m-linux_aarch64.whl
 ```
 
 ## Installing from source package
@@ -145,7 +145,7 @@
 
 Install PyArmNN as follows:
 ```bash
-$ pip install /path/to/pyarmnn-30.0.0.tar.gz
+$ pip install /path/to/pyarmnn-31.0.0.tar.gz
 ```
 
 If PyArmNN installation script fails to find Arm NN libraries it will raise an error like this
@@ -159,7 +159,7 @@
 You can also verify it by running the following and getting output similar to below:
 ```bash
 $ python -c "import pyarmnn as ann;print(ann.GetVersion())"
-'30.0.0'
+'31.0.0'
 ```
 
 # PyArmNN API overview
diff --git a/python/pyarmnn/examples/image_classification/README.md b/python/pyarmnn/examples/image_classification/README.md
index a360f01..04718e2 100644
--- a/python/pyarmnn/examples/image_classification/README.md
+++ b/python/pyarmnn/examples/image_classification/README.md
@@ -20,7 +20,7 @@
 You can also verify it by running the following and getting output similar to below:

 ```bash

 $ python -c "import pyarmnn as ann;print(ann.GetVersion())"

-'30.0.0'

+'31.0.0'

 ```

 

 ##### Dependencies

diff --git a/python/pyarmnn/examples/keyword_spotting/README.md b/python/pyarmnn/examples/keyword_spotting/README.md
index 1c1deaf..98158e6 100644
--- a/python/pyarmnn/examples/keyword_spotting/README.md
+++ b/python/pyarmnn/examples/keyword_spotting/README.md
@@ -18,7 +18,7 @@
 
 ```bash
 $ python -c "import pyarmnn as ann;print(ann.GetVersion())"
-'30.0.0'
+'31.0.0'
 ```
 
 ### Dependencies
diff --git a/python/pyarmnn/examples/object_detection/README.md b/python/pyarmnn/examples/object_detection/README.md
index 215cf77..73bafb6 100644
--- a/python/pyarmnn/examples/object_detection/README.md
+++ b/python/pyarmnn/examples/object_detection/README.md
@@ -54,7 +54,7 @@
 You can also verify it by running the following and getting output similar to below:
 ```bash
 $ python -c "import pyarmnn as ann;print(ann.GetVersion())"
-'30.0.0'
+'31.0.0'
 ```
 
 ##### Dependencies
diff --git a/python/pyarmnn/examples/speech_recognition/README.md b/python/pyarmnn/examples/speech_recognition/README.md
index d5fee8a..e442aad 100644
--- a/python/pyarmnn/examples/speech_recognition/README.md
+++ b/python/pyarmnn/examples/speech_recognition/README.md
@@ -18,7 +18,7 @@
 
 ```bash
 $ python -c "import pyarmnn as ann;print(ann.GetVersion())"
-'30.0.0'
+'31.0.0'
 ```
 
 ### Dependencies
diff --git a/python/pyarmnn/src/pyarmnn/_version.py b/python/pyarmnn/src/pyarmnn/_version.py
index d1b1ca2..d68a893 100644
--- a/python/pyarmnn/src/pyarmnn/_version.py
+++ b/python/pyarmnn/src/pyarmnn/_version.py
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: MIT
 import os
 
-version_info = (30, 0, 0)
+version_info = (31, 0, 0)
 
 __dev_version_env = os.getenv("PYARMNN_DEV_VER", "")
 
@@ -24,7 +24,7 @@
     """Compares expected Arm NN version and Arm NN version used to build the package.
 
     Args:
-        installed_armnn_version (str): Arm NN version used to generate the package (e.g. 30.0.0)
+        installed_armnn_version (str): Arm NN version used to generate the package (e.g. 31.0.0)
         expected_armnn_version (str): Expected Arm NN version
 
     Returns:
diff --git a/python/pyarmnn/test/test_setup.py b/python/pyarmnn/test/test_setup.py
index 27feda2..ada96cc 100644
--- a/python/pyarmnn/test/test_setup.py
+++ b/python/pyarmnn/test/test_setup.py
@@ -87,15 +87,15 @@
 
 
 def test_armnn_version():
-    check_armnn_version('30.0.0', '30.0.0')
+    check_armnn_version('31.0.0', '31.0.0')
 
 
 def test_incorrect_armnn_version():
     with pytest.raises(AssertionError) as err:
-        check_armnn_version('30.0.0', '30.1.0')
+        check_armnn_version('31.0.0', '31.1.0')
 
-    assert 'Expected ArmNN version is 30.1.0 but installed ArmNN version is 30.0.0' in str(err.value)
+    assert 'Expected ArmNN version is 31.1.0 but installed ArmNN version is 31.0.0' in str(err.value)
 
 
 def test_armnn_version_patch_does_not_matter():
-    check_armnn_version('30.0.0', '30.0.1')
+    check_armnn_version('31.0.0', '31.0.1')
diff --git a/python/pyarmnn/test/test_version.py b/python/pyarmnn/test/test_version.py
index 83606ab..f68adff 100644
--- a/python/pyarmnn/test/test_version.py
+++ b/python/pyarmnn/test/test_version.py
@@ -18,7 +18,7 @@
 
     importlib.reload(v)
 
-    assert "30.0.0.dev1" == v.__version__
+    assert "31.0.0.dev1" == v.__version__
 
     del os.environ["PYARMNN_DEV_VER"]
     del v
@@ -30,7 +30,7 @@
 
     importlib.reload(v)
 
-    assert "30.0.0" == v.__arm_ml_version__
+    assert "31.0.0" == v.__arm_ml_version__
 
     del os.environ["PYARMNN_DEV_VER"]
     del v
diff --git a/samples/ObjectDetection/Readme.md b/samples/ObjectDetection/Readme.md
index bd84e26..169546e 100644
--- a/samples/ObjectDetection/Readme.md
+++ b/samples/ObjectDetection/Readme.md
@@ -253,8 +253,8 @@
 The full list of libs after cross-compilation to copy on your board:
 ```
 libarmnn.so
-libarmnn.so.30
-libarmnn.so.30.0
+libarmnn.so.31
+libarmnn.so.31.0
 For Arm NN public C++ API mode:
 libarmnnTfLiteParser.so
 libarmnnTfLiteParser.so.24.4
diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp
index f957627..226d121 100644
--- a/src/armnn/Descriptors.cpp
+++ b/src/armnn/Descriptors.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 #include "armnn/Descriptors.hpp"
@@ -461,80 +461,79 @@
     const TensorShape& tensorXShape,
     const TensorShape& tensorYShape)
 {
-    // May refactor to just work on one input per call - makes it less confusing and also
-    // allows more flexibility (i.e. in Layer output shape inference)
-
-    auto xNumDims = tensorXShape.GetNumDimensions();
-    auto yNumDims = tensorYShape.GetNumDimensions();
-
-    std::pair<unsigned int, unsigned int> xAxes = { xNumDims-2, xNumDims-1 };
-    std::pair<unsigned int, unsigned int> yAxes = { yNumDims-2, yNumDims-1 };
-
-    if(desc.m_DataLayoutX.has_value())
-    {
-        switch(desc.m_DataLayoutX.value())
-        {
-            case DataLayout::NDHWC:
-            case DataLayout::NHWC:
-                xAxes.first -= 1;
-                xAxes.second -= 1;
-                break;
-            case DataLayout::NCDHW:
-            case DataLayout::NCHW:
-            default:
-                break;
-        }
-    }
-
-    if(desc.m_DataLayoutY.has_value())
-    {
-        switch(desc.m_DataLayoutY.value())
-        {
-            case DataLayout::NDHWC:
-            case DataLayout::NHWC:
-                yAxes.first -= 1;
-                yAxes.second -= 1;
-                break;
-            case DataLayout::NCDHW:
-            case DataLayout::NCHW:
-            default:
-                break;
-        }
-    }
-
-    return { xAxes, yAxes};
+    return { GetAxesToMul(desc.m_DataLayoutX, tensorXShape),
+             GetAxesToMul(desc.m_DataLayoutY, tensorYShape) };
 }
-
 std::pair<std::vector<unsigned int>, std::vector<unsigned int>> BatchMatMulDescriptor::GetAxesNotMul(
     const BatchMatMulDescriptor& desc,
     const TensorShape& inputXShape,
     const TensorShape& inputYShape)
 {
-    // May refactor to just work on one input per call - makes it less confusing and also
-    // allows more flexibility (i.e. in Layer output shape inference)
-    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(desc, inputXShape, inputYShape);
+    return { GetAxesNotMul(desc.m_DataLayoutX, inputXShape),
+             GetAxesNotMul(desc.m_DataLayoutY, inputYShape) };
+}
 
-    std::vector<unsigned int> axesXNotMul;
-    std::vector<unsigned int> axesYNotMul;
-
-    for(unsigned int i = 0; i < inputXShape.GetNumDimensions(); i++)
+std::pair<unsigned int, unsigned int> BatchMatMulDescriptor::GetAxesToMul(
+    DataLayout dataLayout,
+    const TensorShape& tensorShape)
+{
+    auto numDims = tensorShape.GetNumDimensions();
+    std::pair<unsigned int, unsigned int> axes = { numDims-2, numDims-1 };
+    switch(dataLayout)
     {
-        if(i == axesToMul.first.first || i == axesToMul.first.second)
+        case DataLayout::NDHWC:
+        case DataLayout::NHWC:
+            axes.first -= 1;
+            axes.second -= 1;
+            break;
+        case DataLayout::NCDHW:
+        case DataLayout::NCHW:
+        default:
+            break;
+    }
+    return axes;
+}
+
+std::vector<unsigned int> BatchMatMulDescriptor::GetAxesNotMul(
+    DataLayout dataLayout,
+    const TensorShape& tensorShape)
+{
+    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
+    std::vector<unsigned int> axesNotMul;
+    for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
+    {
+        if(i == axesToMul.first || i == axesToMul.second)
         {
             continue;
         }
-        axesXNotMul.push_back(i);
+        axesNotMul.push_back(i);
     }
-    for(unsigned int i = 0; i < inputYShape.GetNumDimensions(); i++)
-    {
-        if(i == axesToMul.second.first || i == axesToMul.second.second)
-        {
-            continue;
-        }
-        axesYNotMul.push_back(i);
-    }
+    return axesNotMul;
+}
 
-    return { axesXNotMul, axesYNotMul };
+PermutationVector BatchMatMulDescriptor::GetPermuteVec(
+    DataLayout dataLayout,
+    const TensorShape& tensorShape)
+{
+    std::vector<unsigned int> vec;
+    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
+    for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
+    {
+        if(i == axesToMul.first)
+        {
+            vec.push_back(i+1);
+        }
+        else if(i == axesToMul.second)
+        {
+            vec.push_back(i-1);
+        }
+        else
+        {
+            vec.push_back(i);
+        }
+    }
+    return PermutationVector(vec.data(),
+                             static_cast<unsigned int>(vec.size()));
 }
 
 }
diff --git a/src/armnn/layers/BatchMatMulLayer.cpp b/src/armnn/layers/BatchMatMulLayer.cpp
index 501de2d..acd089a 100644
--- a/src/armnn/layers/BatchMatMulLayer.cpp
+++ b/src/armnn/layers/BatchMatMulLayer.cpp
@@ -5,6 +5,7 @@
 #include "BatchMatMulLayer.hpp"
 
 #include <armnn/backends/WorkloadFactory.hpp>
+#include <armnnUtils/Permute.hpp>
 #include "layers/LayerCloneBase.hpp"
 
 namespace armnn
@@ -36,12 +37,24 @@
     TensorShape inputXShape = inputShapes[0];
     TensorShape inputYShape = inputShapes[1];
 
-    // Note: Take into account what pre-adjoint or pre-transposing will do to the inferred output shape
+    // Adjoint will not affect the resultant shape, as you would be permuting two axes of equal size
+    if(m_Param.m_TransposeX)
+    {
+        auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutX,
+                                                               inputXShape);
+        inputXShape = armnnUtils::Permuted(inputXShape, permuteVec);
+    }
+    if(m_Param.m_TransposeY)
+    {
+        auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutY,
+                                                               inputYShape);
+        inputYShape = armnnUtils::Permuted(inputYShape, permuteVec);
+    }
 
     TensorShape& longerInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
-                               inputXShape:inputYShape;
+                               inputXShape : inputYShape;
     TensorShape& shorterInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
-                                inputYShape:inputXShape;
+                                inputYShape : inputXShape;
 
     unsigned int inputNumDimsOffset = longerInput.GetNumDimensions() - shorterInput.GetNumDimensions();
 
@@ -49,10 +62,10 @@
 
     std::vector<unsigned int> tensorDimensions(outputNumDimensions, 0);
 
-    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Param, inputXShape, inputYShape);
-    const auto& longerAxesToMul = (axesToMul.first.first >= axesToMul.second.first &&
-                             axesToMul.first.second >= axesToMul.second.second) ?
-                                 axesToMul.first : axesToMul.second;
+    const auto& longerInputDataLayout = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
+                                        m_Param.m_DataLayoutX : m_Param.m_DataLayoutY;
+    auto longerAxesToMul = BatchMatMulDescriptor::GetAxesToMul(longerInputDataLayout,
+                                                               longerInput);
 
     for (unsigned int i = 0; i < outputNumDimensions; ++i)
     {
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 9a4c60f..f4afbd9 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -8,6 +8,7 @@
 #include <armnn/backends/WorkloadInfo.hpp>
 #include <armnnUtils/DataLayoutIndexed.hpp>
 #include <armnnUtils/TensorUtils.hpp>
+#include <armnnUtils/Permute.hpp>
 #include <armnn/utility/NumericCast.hpp>
 #include <armnn/Logging.hpp>
 
@@ -4154,9 +4155,10 @@
     // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
     // axes N and I must be the same size
 
-    const auto& inputTensorXInfo = workloadInfo.m_InputTensorInfos[0];
-    const auto& inputTensorYInfo = workloadInfo.m_InputTensorInfos[1];
-    const auto& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+    const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
+    const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
+    const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
+    // Output info has already been inferred
 
     std::vector<DataType> supportedTypes =
     {
@@ -4168,108 +4170,127 @@
         DataType::QSymmS16
     };
 
-    ValidateDataTypes(inputTensorXInfo, supportedTypes, descriptorName);
-    ValidateDataTypes(inputTensorYInfo, supportedTypes, descriptorName);
-    ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+    ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
+    ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
+    ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
 
-    if ((inputTensorXInfo.GetNumDimensions() < 2) ||
-        (inputTensorYInfo.GetNumDimensions() < 2))
+    if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
+        (inputYInfoBeforeParams.GetNumDimensions() < 2))
     {
         throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
     }
 
-    if(m_Parameters.m_DataLayoutX.has_value())
+    TensorInfo inputXInfoAfterParams;
+    TensorInfo inputYInfoAfterParams;
+
+    if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
+       (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
     {
-        switch(m_Parameters.m_DataLayoutX.value())
+        throw InvalidArgumentException(descriptorName +
+            ": Invalid descriptor parameters - Transpose and Adjoint "
+            "cannot both be true for a given input tensor.");
+    }
+    if(m_Parameters.m_TransposeX)
+    {
+        inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
+                                                     BatchMatMulDescriptor::GetPermuteVec(
+                                                         m_Parameters.m_DataLayoutX,
+                                                         inputXInfoBeforeParams.GetShape()));
+    }
+    else if(m_Parameters.m_AdjointX)
+    {
+        auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
+                                                             inputXInfoBeforeParams.GetShape());
+        if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
+           inputXInfoBeforeParams.GetShape()[axesToMul.second])
         {
-            case DataLayout::NCHW:
-            case DataLayout::NHWC:
-                if(inputTensorXInfo.GetNumDimensions() != 4)
-                {
-                    throw InvalidArgumentException(descriptorName +
-                        ": Input tensor X does not have the correct "
-                        "number of dimensions for the Data Layout that it has been assigned.");
-                }
-                break;
-            case DataLayout::NCDHW:
-            case DataLayout::NDHWC:
-                if(inputTensorXInfo.GetNumDimensions() != 5)
-                {
-                    throw InvalidArgumentException(descriptorName +
-                        ": Input tensor X does not have the correct "
-                        "number of dimensions for the Data Layout that it has been assigned.");
-                }
-                break;
-            default:
-                break;
+            throw InvalidArgumentException(descriptorName +
+                ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
         }
+        // Shape remains the same as it's square
+        inputXInfoAfterParams = inputXInfoBeforeParams;
+    }
+    else
+    {
+        inputXInfoAfterParams = inputXInfoBeforeParams;
     }
 
-    if(m_Parameters.m_DataLayoutY.has_value())
+    if(m_Parameters.m_TransposeY)
     {
-        switch(m_Parameters.m_DataLayoutY.value())
+        inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
+                                                     BatchMatMulDescriptor::GetPermuteVec(
+                                                         m_Parameters.m_DataLayoutY,
+                                                         inputYInfoBeforeParams.GetShape()));
+    }
+    else if(m_Parameters.m_AdjointY)
+    {
+        auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
+                                                             inputYInfoBeforeParams.GetShape());
+        if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
+           inputYInfoBeforeParams.GetShape()[axesToMul.second])
         {
-            case DataLayout::NCHW:
-            case DataLayout::NHWC:
-                if(inputTensorYInfo.GetNumDimensions() != 4)
-                {
-                    throw InvalidArgumentException(descriptorName +
-                        ": Input tensor Y does not have the correct "
-                        "number of dimensions for the Data Layout that it has been assigned.");
-                }
-                break;
-            case DataLayout::NCDHW:
-            case DataLayout::NDHWC:
-                if(inputTensorYInfo.GetNumDimensions() != 5)
-                {
-                    throw InvalidArgumentException(descriptorName +
-                        ": Input tensor Y does not have the correct "
-                        "number of dimensions for the Data Layout that it has been assigned.");
-                }
-                break;
-            default:
-                break;
+            throw InvalidArgumentException(descriptorName +
+                ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
         }
+        // Shape remains the same as it's square
+        inputYInfoAfterParams = inputYInfoBeforeParams;
+    }
+    else
+    {
+        inputYInfoAfterParams = inputYInfoBeforeParams;
     }
 
-    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters,
-                                                         inputTensorXInfo.GetShape(),
-                                                         inputTensorYInfo.GetShape());
+    switch(m_Parameters.m_DataLayoutX)
+    {
+        case DataLayout::NCDHW:
+        case DataLayout::NDHWC:
+            if(inputXInfoAfterParams.GetNumDimensions() < 3)
+            {
+                throw InvalidArgumentException(descriptorName +
+                    ": Input tensor X does not have the correct "
+                    "number of dimensions for the Data Layout that it has been assigned.");
+            }
+            break;
+        case DataLayout::NCHW:
+        case DataLayout::NHWC:
+        default:
+            break;
+    }
 
-    if(inputTensorXInfo.GetShape()[axesToMul.first.second]
-       != inputTensorYInfo.GetShape()[axesToMul.second.first])
+    switch(m_Parameters.m_DataLayoutY)
+    {
+        case DataLayout::NCDHW:
+        case DataLayout::NDHWC:
+            if(inputYInfoAfterParams.GetNumDimensions() < 3)
+            {
+                throw InvalidArgumentException(descriptorName +
+                    ": Input tensor Y does not have the correct "
+                    "number of dimensions for the Data Layout that it has been assigned.");
+            }
+            break;
+        case DataLayout::NCHW:
+        case DataLayout::NHWC:
+        default:
+            break;
+    }
+
+    auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
+        inputXInfoAfterParams.GetShape());
+    auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
+        inputXInfoBeforeParams.GetShape());
+
+    if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
+       != inputYInfoAfterParams.GetShape()[axesYToMul.first])
     {
         throw InvalidArgumentException(descriptorName +
             ": The final axis of input tensor X must be the same size as "
             "the second last axis of input tensor Y.");
     }
 
-    auto axesNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters,
-                                                           inputTensorXInfo.GetShape(),
-                                                           inputTensorYInfo.GetShape());
-
     {   // Separate scope so we don't pollute the rest of the scope with our temp variables
         // e.g. NHWC isnt compatible with NCHW as of now
-        DataLayout xLayout;
-        DataLayout yLayout;
-
-        if(m_Parameters.m_DataLayoutX == EmptyOptional())
-        {
-            xLayout = DataLayout::NCHW; // Not equivalent - I'm just concerned with the last 2 axes
-        }
-        else
-        {
-            xLayout = m_Parameters.m_DataLayoutX.value();
-        }
-
-        if(m_Parameters.m_DataLayoutY == EmptyOptional())
-        {
-            yLayout = DataLayout::NCHW;
-        }
-        else
-        {
-            yLayout = m_Parameters.m_DataLayoutY.value();
-        }
+        DataLayout xLayout = m_Parameters.m_DataLayoutX;
+        DataLayout yLayout = m_Parameters.m_DataLayoutY;
 
         if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
         {
@@ -4290,8 +4311,8 @@
     }
 
     // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
-    unsigned int outputTensorDimSize = std::max(inputTensorXInfo.GetNumDimensions(),
-                                                inputTensorYInfo.GetNumDimensions());
+    unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
+                                                inputYInfoAfterParams.GetNumDimensions());
     if(outputTensorDimSize-2 > 0)
     {
         TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
@@ -4312,12 +4333,17 @@
 
             for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
             {
-                ti.GetShape()[i] = inputTensorXInfo.GetShape()[i];
+                ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
             }
         };
 
-        doAxisExtension(axesNotMul.first, tiXNotMul);
-        doAxisExtension(axesNotMul.second, tiYNotMul);
+        auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
+                                                                inputXInfoAfterParams.GetShape());
+        auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
+                                                                inputYInfoAfterParams.GetShape());
+
+        doAxisExtension(axesXNotMul, tiXNotMul);
+        doAxisExtension(axesYNotMul, tiYNotMul);
 
         for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
         {
@@ -4332,42 +4358,6 @@
                                            "input_X",
                                            "input_Y");
     }
-
-    // Also check descriptor parameter validity
-    // This will eventually be moved to the start of the function as explained below
-    if ((!m_Parameters.m_TransposeX.empty() && !m_Parameters.m_AdjointX.empty()) ||
-        (!m_Parameters.m_TransposeY.empty() && !m_Parameters.m_AdjointY.empty()))
-    {
-        throw InvalidArgumentException(descriptorName +
-            ": Invalid descriptor parameters - Transpose and Adjoint "
-            "vectors cannot both be true for a given input tensor.");
-    }
-
-    if(m_Parameters.m_TransposeX.size() != 0 && m_Parameters.m_TransposeX.size() != inputTensorXInfo.GetNumDimensions())
-    {
-        throw InvalidArgumentException(descriptorName +
-            ": Invalid descriptor parameter - Transpose X vector must be "
-            "the same size as tensor input X's dimensionality.");
-    }
-    if(m_Parameters.m_AdjointX.size() != 0 && m_Parameters.m_AdjointX.size() != inputTensorXInfo.GetNumDimensions())
-    {
-        throw InvalidArgumentException(descriptorName +
-            ": Invalid descriptor parameter - Adjoint X vector must be "
-            "the same size as tensor input X's dimensionality.");
-    }
-    if(m_Parameters.m_TransposeY.size() != 0 && m_Parameters.m_TransposeY.size() != inputTensorYInfo.GetNumDimensions())
-    {
-        throw InvalidArgumentException(descriptorName +
-            ": Invalid descriptor parameter - Transpose Y vector must be "
-            "the same size as tensor input Y's dimensionality.");
-    }
-    if(m_Parameters.m_AdjointY.size() != 0 && m_Parameters.m_AdjointY.size() != inputTensorXInfo.GetNumDimensions())
-    {
-        throw InvalidArgumentException(descriptorName +
-            ": Invalid descriptor parameter - Adjoint Y vector must be "
-            "the same size as tensor input Y's dimensionality.");
-    }
-    // Note: for adjoint/transpose, you'll need to do the validation atop the resultant permutation.
 }
 
 
diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
index 41add6e..6fcc35a 100644
--- a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
@@ -191,7 +191,7 @@
     std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
         19, 22,
         43, 50
-    },qScale, qOffset);
+    }, qScale, qOffset);
 
     return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
                                                 memoryManager,
@@ -247,9 +247,7 @@
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
     const armnn::ITensorHandleFactory& tensorHandleFactory)
 {
-    auto descriptor = armnn::BatchMatMulDescriptor(
-        armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NCHW),
-        armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NCHW));
+    auto descriptor = armnn::BatchMatMulDescriptor(); // Default arbitrary layout is treated the same as NCHW
 
     float qScale = 0.0f;
     int32_t qOffset = 0;
@@ -282,7 +280,7 @@
     std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
         19, 22,
         43, 50
-    },qScale, qOffset);
+    }, qScale, qOffset);
 
     return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
                                                 memoryManager,
@@ -338,9 +336,12 @@
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
     const armnn::ITensorHandleFactory& tensorHandleFactory)
 {
-    auto descriptor = armnn::BatchMatMulDescriptor(
-        armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC),
-        armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC));
+    auto descriptor = armnn::BatchMatMulDescriptor(false,
+                                                   false,
+                                                   false,
+                                                   false,
+                                                   armnn::DataLayout::NHWC,
+                                                   armnn::DataLayout::NHWC);
 
     float qScale = 0.0f;
     int32_t qOffset = 0;
@@ -373,7 +374,7 @@
     std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
         19, 22,
         43, 50
-    },qScale, qOffset);
+    }, qScale, qOffset);
 
     return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
                                                 memoryManager,
@@ -471,7 +472,7 @@
 
         267, 286,
         323, 346
-    },qScale, qOffset);
+    }, qScale, qOffset);
 
     return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
                                                 memoryManager,
@@ -566,7 +567,7 @@
 
         267, 286,
         323, 346
-    },qScale, qOffset);
+    }, qScale, qOffset);
 
     return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
                                                 memoryManager,
@@ -661,7 +662,7 @@
 
         267, 286,
         323, 346
-    },qScale, qOffset);
+    }, qScale, qOffset);
 
     return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
                                                 memoryManager,
@@ -717,9 +718,12 @@
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
     const armnn::ITensorHandleFactory& tensorHandleFactory)
 {
-    auto descriptor = armnn::BatchMatMulDescriptor(
-        armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NDHWC),
-        armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC));
+    auto descriptor = armnn::BatchMatMulDescriptor(false,
+                                                   false,
+                                                   false,
+                                                   false,
+                                                   armnn::DataLayout::NDHWC,
+                                                   armnn::DataLayout::NHWC);
 
     float qScale = 0.0f;
     int32_t qOffset = 0;
@@ -761,7 +765,7 @@
 
        34, 1079,
        46, 1167
-    },qScale, qOffset);
+    }, qScale, qOffset);
 
     return BatchMatMulTestImpl<ArmnnType, T, 5>(workloadFactory,
                                                 memoryManager,
@@ -959,7 +963,7 @@
         88, 100, 142, 106,
         39, 61, 78, 56,
         72, 52, 98, 70
-    },qScale, qOffset);
+    }, qScale, qOffset);
 
     return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
                                                 memoryManager,
@@ -1007,4 +1011,330 @@
 BatchMatMul3DNonSquareTest<armnn::DataType::QSymmS16>(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 2> BatchMatMul2DTranspSimpleTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    auto descriptor = armnn::BatchMatMulDescriptor(true,
+                                                   false,
+                                                   false,
+                                                   false);
+
+    float qScale = 0.0f;
+    int32_t qOffset = 0;
+
+    switch(ArmnnType)
+    {
+        case armnn::DataType::QAsymmS8:
+        case armnn::DataType::QAsymmU8:
+        case armnn::DataType::QSymmS16:
+            qScale = 1.0f;
+            break;
+        default:
+            break;
+    }
+
+    armnn::TensorInfo inputXInfo({2,3}, ArmnnType, qScale, qOffset);
+    armnn::TensorInfo inputYInfo({2,3}, ArmnnType, qScale, qOffset);
+    armnn::TensorInfo outputInfo({3,3}, ArmnnType, qScale, qOffset);
+
+    std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+        1, 2, 3,
+        4, 5, 6
+    }, qScale, qOffset);
+
+    std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+        7, 8, 9,
+        10, 11, 12
+    }, qScale, qOffset);
+
+    std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+        47, 52, 57,
+        64, 71, 78,
+        81, 90, 99
+    }, qScale, qOffset);
+
+    return BatchMatMulTestImpl<ArmnnType, T, 2>(workloadFactory,
+                                                memoryManager,
+                                                tensorHandleFactory,
+                                                descriptor,
+                                                inputX,
+                                                inputY,
+                                                outputExpected,
+                                                inputXInfo,
+                                                inputYInfo,
+                                                outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::BFloat16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::Float32>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::Float16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::QAsymmS8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::QAsymmU8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::QSymmS16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 2> BatchMatMul2DAdjointSimpleTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    auto descriptor = armnn::BatchMatMulDescriptor(false,
+                                                   false,
+                                                   true,
+                                                   false);
+
+    float qScale = 0.0f;
+    int32_t qOffset = 0;
+
+    switch(ArmnnType)
+    {
+        case armnn::DataType::QAsymmS8:
+        case armnn::DataType::QAsymmU8:
+        case armnn::DataType::QSymmS16:
+            qScale = 1.0f;
+            break;
+        default:
+            break;
+    }
+
+    armnn::TensorInfo inputXInfo({3,3}, ArmnnType, qScale, qOffset);
+    armnn::TensorInfo inputYInfo({3,3}, ArmnnType, qScale, qOffset);
+    armnn::TensorInfo outputInfo({3,3}, ArmnnType, qScale, qOffset);
+
+    std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+        3, 1, 1,
+        1, 3, -1,
+        2, 4, 1
+    }, qScale, qOffset);
+
+    std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+        1, 0, 0,
+        0, 1, 0,
+        0, 0, 1
+    }, qScale, qOffset);
+
+    std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+        7, 3, -4,
+        -3, 1, 4,
+        -2, -10, 8
+    }, qScale, qOffset);
+
+    switch (ArmnnType)
+    {
+        case armnn::DataType::QAsymmU8:
+            outputExpected = armnnUtils::QuantizedVector<T>({
+                3, 3, 0,
+                0, 1, 1,
+                0, 0, 8
+            }, qScale, qOffset);
+            break;
+        default:
+            break;
+    }
+
+    return BatchMatMulTestImpl<ArmnnType, T, 2>(workloadFactory,
+                                                memoryManager,
+                                                tensorHandleFactory,
+                                                descriptor,
+                                                inputX,
+                                                inputY,
+                                                outputExpected,
+                                                inputXInfo,
+                                                inputYInfo,
+                                                outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::BFloat16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::Float32>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::Float16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::QAsymmS8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::QAsymmU8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::QSymmS16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 4> BatchMatMulNHWCParamsTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    auto descriptor = armnn::BatchMatMulDescriptor(false,
+                                                   true,
+                                                   true,
+                                                   false,
+                                                   armnn::DataLayout::NHWC,
+                                                   armnn::DataLayout::NHWC);
+
+    float qScale = 0.0f;
+    int32_t qOffset = 0;
+
+    switch(ArmnnType)
+    {
+        case armnn::DataType::QAsymmS8:
+        case armnn::DataType::QAsymmU8:
+        case armnn::DataType::QSymmS16:
+            qScale = 1.0f;
+            break;
+        default:
+            break;
+    }
+
+    armnn::TensorInfo inputXInfo({1,4,4,2}, ArmnnType, qScale, qOffset);
+    armnn::TensorInfo inputYInfo({2,2,4,1}, ArmnnType, qScale, qOffset);
+    armnn::TensorInfo outputInfo({2,4,2,2}, ArmnnType, qScale, qOffset);
+
+    std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+       1, -3, 1, 4, 4, 9, 1, 2,
+       2, 4, 2, 2, 10, 7, 6, -5,
+       3, 8, 9, 9, 21, 1, 17, 7,
+       5, 11, 11, 8, 29, 3, 23, 6
+    }, qScale, qOffset);
+
+    std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+        1, 2, 3, 4,
+        5, 6, 7, 8,
+
+        9, 10, 11, 12,
+        13, 14, 15, 16
+    }, qScale, qOffset);
+
+    std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+        28, 625, 140, 585,
+        8, 110, -8, 1662,
+        -24, 401, -120, 921,
+        12, 131, 108, -501,
+
+        252, 545, 364, 505,
+        -24, 3214, -40, 4766,
+        -216, 1441, -312, 1961,
+        204, -1133, 300, -1765
+    }, qScale, qOffset);
+
+    switch (ArmnnType)
+    {
+        case armnn::DataType::QAsymmU8:
+            outputExpected = armnnUtils::QuantizedVector<T>({
+                28, 80, 140, 80,
+                8, 45, 0, 255,
+                0, 18, 0, 18,
+                12, 0, 108, 0,
+
+                252, 80, 255, 80,
+                0, 255, 0, 255,
+                0, 18, 0, 18,
+                204, 0, 255, 0
+            }, qScale, qOffset);
+            break;
+        default:
+            break;
+    }
+
+    return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
+                                                memoryManager,
+                                                tensorHandleFactory,
+                                                descriptor,
+                                                inputX,
+                                                inputY,
+                                                outputExpected,
+                                                inputXInfo,
+                                                inputYInfo,
+                                                outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::BFloat16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::Float32>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::Float16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::QAsymmS8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::QAsymmU8>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::QSymmS16>(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
     const armnn::ITensorHandleFactory& tensorHandleFactory);
\ No newline at end of file
diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
index 9e21396..0b261fb 100644
--- a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
@@ -82,4 +82,22 @@
 LayerTestResult<T, 3> BatchMatMul3DNonSquareTest(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 2> BatchMatMul2DTranspSimpleTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 2> BatchMatMul2DAdjointSimpleTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> BatchMatMulNHWCParamsTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
     const armnn::ITensorHandleFactory& tensorHandleFactory);
\ No newline at end of file
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 593dc78..ae40333 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -1133,6 +1133,27 @@
 ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmU8, BatchMatMul3DNonSquareTest<DataType::QAsymmU8>);
 ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQASymmS16, BatchMatMul3DNonSquareTest<DataType::QSymmS16>);
 
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleBFloat16, BatchMatMul2DTranspSimpleTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleFloat32, BatchMatMul2DTranspSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleFloat16, BatchMatMul2DTranspSimpleTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleQAsymmS8, BatchMatMul2DTranspSimpleTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleQAsymmU8, BatchMatMul2DTranspSimpleTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleQASymmS16,BatchMatMul2DTranspSimpleTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleBFloat16, BatchMatMul2DAdjointSimpleTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleFloat32, BatchMatMul2DAdjointSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleFloat16, BatchMatMul2DAdjointSimpleTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleQAsymmS8, BatchMatMul2DAdjointSimpleTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleQAsymmU8, BatchMatMul2DAdjointSimpleTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleQASymmS16,BatchMatMul2DAdjointSimpleTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsBFloat16, BatchMatMulNHWCParamsTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsFloat32, BatchMatMulNHWCParamsTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsFloat16, BatchMatMulNHWCParamsTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsQAsymmS8, BatchMatMulNHWCParamsTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsQAsymmU8, BatchMatMulNHWCParamsTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsQASymmS16, BatchMatMulNHWCParamsTest<DataType::QSymmS16>);
+
 // Batch Norm
 ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32, BatchNormFloat32Test)
 ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32Nhwc, BatchNormFloat32NhwcTest)
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.cpp b/src/backends/reference/workloads/BatchMatMulImpl.cpp
index 6693f15..c592b3b 100644
--- a/src/backends/reference/workloads/BatchMatMulImpl.cpp
+++ b/src/backends/reference/workloads/BatchMatMulImpl.cpp
@@ -7,46 +7,53 @@
 
 #include <armnn/backends/WorkloadData.hpp>
 #include <armnn/Logging.hpp>
+#include <armnnUtils/Permute.hpp>
 
 namespace armnn
 {
 
-void BatchMatMul::BatchMatMulImpl()
+BatchMatMul::BatchMatMul(const BatchMatMulDescriptor& params,
+                         const TensorInfo& inputXInfo,
+                         const TensorInfo& inputYInfo,
+                         const TensorInfo& outputInfo,
+                         Decoder<float>& inputXDecoder,
+                         Decoder<float>& inputYDecoder,
+                         Encoder<float>& outputEncoder)
+    : params(params),
+      inputXInfo(inputXInfo),
+      inputYInfo(inputYInfo),
+      outputInfo(outputInfo),
+      inputXDecoder(inputXDecoder),
+      inputYDecoder(inputYDecoder),
+      outputEncoder(outputEncoder)
 {
-    inputXData = inputXDecoder.DecodeTensor(inputXInfo.GetShape());
-    inputYData = inputYDecoder.DecodeTensor(inputYInfo.GetShape());
+    inputXData = this->inputXDecoder.DecodeTensor(inputXInfo.GetShape());
+    inputYData = this->inputYDecoder.DecodeTensor(inputYInfo.GetShape());
     // At this point, we don't touch the input decoders - just the resultant vectors
 
-    // Pre-transpose and pre-adjoint if their vectors aren't empty
-    // and also DataLayouts which may change with permutations/adjoints
+    ApplyParams();
 
-    // Todo: Have you updated input validation and inferred output shapes to accommodate for these pre-permutes?
-
-    auto idx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
-    RecurseBMM(idx, 0);
+    ApplyBatchMatMul();
 }
 
-void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim)
+void BatchMatMul::ApplyBatchMatMul()
 {
-    // We're working off of the indexes of the output tensor (the max possible shape)
+    auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutX,
+                                                          inputXInfo.GetShape());
+    auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutY,
+                                                          inputYInfo.GetShape());
+    AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
 
-    if(!(curDim < outputInfo.GetNumDimensions()))
+    unsigned int inputXColDim = axesXToMul.second;
+    unsigned int inputYRowDim = axesYToMul.first;
+
+    unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
+
+    auto batchMatMulOperation = [&](const std::vector<unsigned int>& curIdx)
     {
-        // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
-
-        auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(params,
-                                                             inputXInfo.GetShape(),
-                                                             inputYInfo.GetShape());
-        AdjustAxesToMulForUnequalRanks(axesToMul);
-
-        unsigned int inputXColDim = axesToMul.first.second;
-        unsigned int inputYRowDim = axesToMul.second.first;
-
-        unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
-
         float sum = 0.0f;
 
-        // You could also use inputXColSize
+        // InputYRowSize is synonymous with inputXColSize
         for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) {
             auto xIdx = curIdx;
             xIdx[inputXColDim] = inputYRowIdx;
@@ -54,24 +61,271 @@
             auto yIdx = curIdx;
             yIdx[inputYRowDim] = inputYRowIdx;
 
-            sum += (GetValueAt(DataSlot::InputX, xIdx)
-                  * GetValueAt(DataSlot::InputY, yIdx));
+            sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
         }
 
         SetValueAt(sum, DataSlot::Output, curIdx);
+    };
 
-        return;
-    }
+    auto startIdx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
+    RecurseTensor(outputInfo,
+                  batchMatMulOperation,
+                  startIdx,
+                  0);
+}
 
-    for (unsigned int i = 0; i < outputInfo.GetShape()[curDim]; i++)
+void BatchMatMul::ApplyParams()
+{
+    if(params.m_TransposeX)
     {
-        curIdx[curDim] = i;
-        RecurseBMM(curIdx, curDim+1);
+        Transpose(DataSlot::InputX);
+    }
+    else if(params.m_AdjointX)
+    {
+        Adjoint(DataSlot::InputX);
+    }
+    if(params.m_TransposeY)
+    {
+        Transpose(DataSlot::InputY);
+    }
+    else if(params.m_AdjointY)
+    {
+        Adjoint(DataSlot::InputY);
     }
 }
 
-void BatchMatMul::AdjustAxesToMulForUnequalRanks(
-    std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul)
+void BatchMatMul::Transpose(DataSlot type)
+{
+    // AKA the permute of the tensor
+    // This modifies the tensor's info.
+
+    switch(type)
+    {
+        case DataSlot::InputX:
+        {
+            auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutX,
+                                                                   inputXInfo.GetShape());
+            inputXInfo = armnnUtils::Permuted(inputXInfo, permuteVec);
+            std::vector<float> temp(inputXData.size());
+            armnnUtils::Permute(inputXInfo.GetShape(),
+                                permuteVec,
+                                inputXData.data(),
+                                temp.data(),
+                                sizeof(float));
+            inputXData = temp;
+            break;
+        }
+        case DataSlot::InputY:
+        {
+            auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutY,
+                                                                   inputYInfo.GetShape());
+            inputYInfo = armnnUtils::Permuted(inputYInfo, permuteVec);
+            std::vector<float> temp(inputYData.size());
+            armnnUtils::Permute(inputYInfo.GetShape(),
+                                permuteVec,
+                                inputYData.data(),
+                                temp.data(),
+                                sizeof(float));
+            inputYData = temp;
+            break;
+        }
+        case DataSlot::Output: // We needn't transpose the output tensor
+        default:
+            break;
+    }
+}
+
+void BatchMatMul::Adjoint(DataSlot type)
+{
+    // Finding the adjoint of a square matrix:
+    // Calculate the cofactor of each element (using Gauss elimination here)
+    // Apply a transpose to it (this also modifies the tensor's info)
+
+    TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
+    const auto& dataLayout = (type == DataSlot::InputX) ? params.m_DataLayoutX : params.m_DataLayoutY;
+    const auto axesToAdjoint = BatchMatMulDescriptor::GetAxesToMul(dataLayout,inputInfo.GetShape());
+
+    ARMNN_ASSERT(inputInfo.GetShape()[axesToAdjoint.first] == inputInfo.GetShape()[axesToAdjoint.second]);
+    // We grab a copy of the tensor data to prevent overwriting
+    std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;
+
+    // The sub-matrix is the resultant matrix when the row and column of the current index is removed
+    unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
+    std::vector<std::vector<float>> subMat(subMatAxisSize,
+                                           std::vector<float>(subMatAxisSize));
+
+    // Lambdas for each sub-step of the cofactor operation
+    auto almostEquals = [&](const float& a, const float& b, float unitsInLastPlace = 2.0f)
+    {
+        float diff = std::fabs(a-b);
+        float bound = diff * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
+        return (diff <= bound) || (diff < std::numeric_limits<float>::min());
+    };
+
+    float swapMultiplier = std::numeric_limits<float>::max();
+    auto swapRows = [&](unsigned int rowIdxA, unsigned int rowIdxB)
+    {
+        // Every row swap flips this around by the negative (set to 1 at the beginning of each cofactor op run)
+        for(unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
+        {
+            float tmp = subMat[rowIdxA][colIdx];
+            subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
+            subMat[rowIdxB][colIdx] = tmp;
+        }
+        swapMultiplier *= -1.0f;
+    };
+
+    auto findNextValidPivotRowIdx = [&](unsigned int colIdx)
+    {
+        unsigned int result = std::numeric_limits<unsigned int>::max();
+
+        // The original diagonal has been checked and is invalid
+        for(unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
+        {
+            if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
+            {
+                result = rowIdx;
+                break;
+            }
+        }
+        return result;
+    };
+
+    auto eliminate = [&](const float& pivot, unsigned int pivotPos)
+    {
+        for(unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
+        {
+            float multiplierNumerator = subMat[rowIdx][pivotPos];
+            if(almostEquals(multiplierNumerator, 0.0f))
+            {
+                continue;
+            }
+            float multiplier = multiplierNumerator / pivot; // Susceptible to floating point inaccuracies
+                                                            // Hence the almostEquals usage to counteract this
+            for(unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
+            {
+                // We start at col=pivotPos as we have assumed that all elements
+                // to our left have been eliminated to zero already
+
+                // We subtract based on the element directly above us in our pivot row
+                subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
+            }
+        }
+    };
+
+    auto cofactorOperation = [&](const std::vector<unsigned int>& curIdx)
+    {
+        auto row = curIdx[axesToAdjoint.first];
+        auto col = curIdx[axesToAdjoint.second];
+
+        float minorMultiplier = static_cast<float>(std::pow(-1, (row + 1 + col + 1)));
+
+        for(unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
+        {
+            for(unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
+            {
+                unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
+                unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
+                auto cloneIdx = curIdx;
+                cloneIdx[axesToAdjoint.first] = outerRow;
+                cloneIdx[axesToAdjoint.second] = outerCol;
+                subMat[subRow][subCol] = GetValueAt(type,cloneIdx,inputDataClone);
+            }
+        }
+
+        float determinant = 1.0f;
+
+        // Cover the edge cases and simple base cases before resorting to Gauss elimination for larger matrices
+        switch(subMatAxisSize)
+        {
+            case 0:
+            {
+                determinant = GetValueAt(type, curIdx, inputDataClone);
+                break;
+            }
+            case 1:
+            {
+                // If the resultant sub-matrix is just one element - that's the determinant
+                determinant = subMat[0][0];
+                break;
+            }
+            case 2:
+            {
+                // For a 2x2 sub-matrix, the determinant is just a*d-b*c
+                determinant = subMat[0][0] * subMat[1][1] -
+                              subMat[0][1] * subMat[1][0];
+                break;
+            }
+            default:
+            {
+                // Gaussian elimination to find the determinant of this sub-matrix
+                swapMultiplier = 1.0f;
+                // March diagonally down the pivots and if it's invalid (a zero), swap the row with the
+                // nearest non-zero down within the column
+                for(unsigned int pivotRow = 0, pivotCol = 0;
+                    pivotRow < subMatAxisSize;
+                    pivotRow++, pivotCol++)
+                {
+                    float& pivot = subMat[pivotRow][pivotCol];
+
+                    if(almostEquals(pivot, 0.0f))
+                    {
+                        unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
+                        if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
+                        {
+                            // No valid pivot down this column, which means that this pivot remains a zero.
+                            // This results in the determinant for this entire sub-matrix to just be zero.
+                            determinant = 0.0f;
+                            break;
+                        }
+                        swapRows(pivotRow, nextValidPivotRowIdx);
+                    }
+                    determinant *= pivot;
+                    // The actual elimination bit (which will update/propagate to the pivots down the line)
+                    eliminate(pivot, pivotRow); // Synonymous with pivotCol
+                }
+
+                determinant *= swapMultiplier;
+                break;
+            }
+        }
+        float cofactor = minorMultiplier * determinant;
+        SetValueAt(cofactor, type, curIdx);
+    };
+
+    auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
+    RecurseTensor(inputInfo,
+                  cofactorOperation,
+                  startIdx,
+                  0);
+
+    Transpose(type);
+}
+
+void BatchMatMul::RecurseTensor(const TensorInfo& tensorInfo,
+                                const std::function<void(const std::vector<unsigned int>&)>& operation,
+                                std::vector<unsigned int>& curIdx,
+                                unsigned int curDim)
+{
+    if(!(curDim < tensorInfo.GetNumDimensions()))
+    {
+        // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
+        operation(curIdx);
+        return;
+    }
+
+    for(unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
+    {
+        curIdx[curDim] = i;
+        RecurseTensor(tensorInfo,
+                      operation,
+                      curIdx,
+                      curDim + 1);
+    }
+}
+
+void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
+                                                 std::pair<unsigned int, unsigned int>& axesYToMul)
 {
     int rankDiff = static_cast<int>(inputXInfo.GetNumDimensions()) -
                    static_cast<int>(inputYInfo.GetNumDimensions());
@@ -82,18 +336,18 @@
     else if(rankDiff < 0)
     {
         // Y is the larger one
-        axesToMul.first.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
-        axesToMul.first.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+        axesXToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+        axesXToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
     }
     else if(rankDiff > 0)
     {
         // X is the larger one
-        axesToMul.second.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
-        axesToMul.second.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+        axesYToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+        axesYToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
     }
 }
 
-float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
+float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData)
 {
     // This gets the data from the input vector that we have, Not the decoder
     // But for the output, it is operating on the encoder itself
@@ -101,14 +355,13 @@
     AdjustToSafeIdx(type, idx);
     unsigned int flatIdx = CalcFlatIdx(type, idx);
     float value = 0.0f;
-
     switch(type)
     {
         case DataSlot::InputX:
-            value = inputXData[flatIdx];
+            value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
             break;
         case DataSlot::InputY:
-            value = inputYData[flatIdx];
+            value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
             break;
         case DataSlot::Output:
             outputEncoder[flatIdx];
@@ -124,9 +377,7 @@
 void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
 {
     AdjustToSafeIdx(type, idx);
-
     unsigned int flatIdx = CalcFlatIdx(type, idx);
-
     switch(type)
     {
         case DataSlot::InputX:
@@ -186,9 +437,7 @@
 unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
 {
     unsigned int result = idx[idx.size()-1];
-
     unsigned int dimMultiplier = 1;
-
     unsigned int offset;
 
     // -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x)
@@ -215,17 +464,4 @@
     return result;
 }
 
-template <typename T>
-std::string BatchMatMul::StringifyVec(const std::vector<T>& vec)
-{
-    std::string res = "{ ";
-    for(auto x : vec)
-    {
-        res += std::to_string(x);
-        res += " ";
-    }
-    res += "}";
-    return res;
-}
-
 } // namespace armnn
\ No newline at end of file
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.hpp b/src/backends/reference/workloads/BatchMatMulImpl.hpp
index 25b6c85..19971a4 100644
--- a/src/backends/reference/workloads/BatchMatMulImpl.hpp
+++ b/src/backends/reference/workloads/BatchMatMulImpl.hpp
@@ -15,6 +15,15 @@
 
 class BatchMatMul {
 public:
+    BatchMatMul(const BatchMatMulDescriptor& params,
+                const TensorInfo& inputXInfo,
+                const TensorInfo& inputYInfo,
+                const TensorInfo& outputInfo,
+                Decoder<float>& inputXDecoder,
+                Decoder<float>& inputYDecoder,
+                Encoder<float>& outputEncoder);
+
+private:
     enum DataSlot
     {
         InputX = 0,
@@ -22,47 +31,10 @@
         Output = 2
     };
 
-    BatchMatMul(const BatchMatMulDescriptor& params,
-                const TensorInfo& inputXInfo,
-                const TensorInfo& inputYInfo,
-                const TensorInfo& outputInfo,
-                Decoder<float>& inputXDecoder,
-                Decoder<float>& inputYDecoder,
-                Encoder<float>& outputEncoder)
-        : params(params),
-          inputXInfo(inputXInfo),
-          inputYInfo(inputYInfo),
-          outputInfo(outputInfo),
-          inputXDecoder(inputXDecoder),
-          inputYDecoder(inputYDecoder),
-          outputEncoder(outputEncoder)
-    {}
-
-    void BatchMatMulImpl();
-
-    void RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim);
-
-    // Adjusts it for when input tensors are of unequal rank
-    void AdjustAxesToMulForUnequalRanks(
-        std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul);
-
-    float GetValueAt(DataSlot type, std::vector<unsigned int> idx);
-
-    void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);
-
-    // Takes into account broadcasting
-    void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);
-
-    unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
-
-    template <typename T>
-    std::string StringifyVec(const std::vector<T>& vec);
-
-private:
     const BatchMatMulDescriptor& params;
-    const TensorInfo& inputXInfo;
-    const TensorInfo& inputYInfo;
-    const TensorInfo& outputInfo;
+    TensorInfo inputXInfo;
+    TensorInfo inputYInfo;
+    TensorInfo outputInfo;
     Decoder<float>& inputXDecoder;
     Decoder<float>& inputYDecoder;
     Encoder<float>& outputEncoder;
@@ -70,6 +42,31 @@
     std::vector<float> inputXData;
     std::vector<float> inputYData;
 
+    void ApplyBatchMatMul();
+
+    void ApplyParams();
+
+    void Transpose(DataSlot type);
+
+    void Adjoint(DataSlot type);
+
+    void RecurseTensor(const TensorInfo& tensorInfo,
+                       std::function<void(const std::vector<unsigned int>&)> const& operation,
+                       std::vector<unsigned int>& curIdx,
+                       unsigned int curDim);
+
+    // Adjusts it for when input tensors are of unequal rank
+    void AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
+                                        std::pair<unsigned int, unsigned int>& axesYToMul);
+
+    float GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData = {});
+
+    void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);
+
+    // Takes into account broadcasting
+    void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);
+
+    unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
 };
 
 } // namespace armnn
\ No newline at end of file
diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
index 388190c..027b93b 100644
--- a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
+++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
@@ -51,9 +51,6 @@
                            *inputXDecoder,
                            *inputYDecoder,
                            *outputEncoder);
-
-    bmm.BatchMatMulImpl();
-
 }
 
 } // namespace armnn
\ No newline at end of file