IVGCVSW-6494 Add CpuAcc Batch MatMul Workload Fp32


Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I2def6995f81d33e68f1ea45d8d19a1e6294049b1
diff --git a/delegate/src/test/BatchMatMulTest.cpp b/delegate/src/test/BatchMatMulTest.cpp
index 5469bc8..e5cb976 100644
--- a/delegate/src/test/BatchMatMulTest.cpp
+++ b/delegate/src/test/BatchMatMulTest.cpp
@@ -654,4 +654,20 @@
         }
     }
 
+    TEST_SUITE("BATCH_MATMUL_CpuAccTests")
+    {
+        TEST_CASE("BATCH_MATMUL_Fp32_CpuAccTests")
+        {
+            std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
+            BatchMatMul2DFp32SimpleTest       (backends);
+            BatchMatMul3DFp32SimpleTest       (backends);
+            BatchMatMul4DFp32SimpleTest       (backends);
+            BatchMatMul3DFp32BatchTest        (backends);
+            BatchMatMul3DFp32BroadcastTest    (backends);
+            BatchMatMul3D2DFp32BroadcastTest  (backends);
+            BatchMatMul2DFp32TinyTest         (backends);
+            BatchMatMulNonSquareFp32Test      (backends);
+            BatchMatMul2DFp32SimpleAdjointTest(backends);
+        }
+    }
 }
diff --git a/docs/02_operator_list.dox b/docs/02_operator_list.dox
index 658aa07..3a902c8 100644
--- a/docs/02_operator_list.dox
+++ b/docs/02_operator_list.dox
@@ -293,12 +293,13 @@
   <td>CpuAcc
   <td>
       <ul>
-       <li>N/A
+       <li>All
       </ul>
   <td>
-      <ul>
-       <li>N/A
-      </ul>
+      <table>
+       <tr><th>
+       <tr><td>FLOAT32
+      </table>
 <tr>
   <td>GpuAcc
   <td>
diff --git a/src/backends/backendsCommon/WorkloadUtils.cpp b/src/backends/backendsCommon/WorkloadUtils.cpp
index b045530..3aea667 100644
--- a/src/backends/backendsCommon/WorkloadUtils.cpp
+++ b/src/backends/backendsCommon/WorkloadUtils.cpp
@@ -341,4 +341,24 @@
     return keyIndices;
 }
 
+armnn::PermutationVector GeneratePermutationVectorOnLastTwoDimensions(unsigned int rank)
+{
+    armnn::PermutationVector permutationVector{};
+    switch (rank)
+    {
+        case 2:
+            permutationVector = {1U, 0U};
+            break;
+        case 3:
+            permutationVector = {0U, 2U, 1U};
+            break;
+        case 4:
+            permutationVector = {0U, 1U, 3U, 2U};
+            break;
+        default:
+            throw Exception("Invalid number of dimensions.");
+    }
+    return permutationVector;
+}
+
 } // namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadUtils.hpp b/src/backends/backendsCommon/WorkloadUtils.hpp
index 0e54873..3d8d927 100644
--- a/src/backends/backendsCommon/WorkloadUtils.hpp
+++ b/src/backends/backendsCommon/WorkloadUtils.hpp
@@ -258,4 +258,10 @@
 /// \return - A map with names and values for  N, ND, K, W, C
 std::map<std::string, unsigned int> CalculateGatherNdKeyIndices(TensorInfo inputInfo0, TensorInfo inputInfo1);
 
+/// Generates a permutation vector of size rank that permutes the 2 most right dimensions
+///
+/// \param rank - Tensor rank, i.e. number of dimensions in the tensors
+/// \return - A permutation vector that permutes the 2 last dimensions
+armnn::PermutationVector GeneratePermutationVectorOnLastTwoDimensions(unsigned int rank);
+
 }  //namespace armnn
diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
index 6fcc35a..74bd97f 100644
--- a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
@@ -71,20 +71,9 @@
 {
     auto descriptor = armnn::BatchMatMulDescriptor(); // Arbitrary layout with no transpose/adjointing
 
-    float qScale = 0.0f;
+    float qScale = 1.0f;
     int32_t qOffset = 0;
 
-    switch(ArmnnType)
-    {
-        case armnn::DataType::QAsymmS8:
-        case armnn::DataType::QAsymmU8:
-        case armnn::DataType::QSymmS16:
-            qScale = 1.0f;
-            break;
-        default:
-            break;
-    }
-
     armnn::TensorInfo inputXInfo({2,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo inputYInfo({2,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo outputInfo({2,2}, ArmnnType, qScale, qOffset);
@@ -160,20 +149,9 @@
 {
     auto descriptor = armnn::BatchMatMulDescriptor(); // Arbitrary layout with no transpose/adjointing
 
-    float qScale = 0.0f;
+    float qScale = 1.0f;
     int32_t qOffset = 0;
 
-    switch(ArmnnType)
-    {
-        case armnn::DataType::QAsymmS8:
-        case armnn::DataType::QAsymmU8:
-        case armnn::DataType::QSymmS16:
-            qScale = 1.0f;
-            break;
-        default:
-            break;
-    }
-
     armnn::TensorInfo inputXInfo({1,2,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo inputYInfo({1,2,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo outputInfo({1,2,2}, ArmnnType, qScale, qOffset);
@@ -249,20 +227,9 @@
 {
     auto descriptor = armnn::BatchMatMulDescriptor(); // Default arbitrary layout is treated the same as NCHW
 
-    float qScale = 0.0f;
+    float qScale = 1.0f;
     int32_t qOffset = 0;
 
-    switch(ArmnnType)
-    {
-        case armnn::DataType::QAsymmS8:
-        case armnn::DataType::QAsymmU8:
-        case armnn::DataType::QSymmS16:
-            qScale = 1.0f;
-            break;
-        default:
-            break;
-    }
-
     armnn::TensorInfo inputXInfo({1,1,2,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo inputYInfo({1,1,2,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo outputInfo({1,1,2,2}, ArmnnType, qScale, qOffset);
@@ -343,20 +310,9 @@
                                                    armnn::DataLayout::NHWC,
                                                    armnn::DataLayout::NHWC);
 
-    float qScale = 0.0f;
+    float qScale = 1.0f;
     int32_t qOffset = 0;
 
-    switch(ArmnnType)
-    {
-        case armnn::DataType::QAsymmS8:
-        case armnn::DataType::QAsymmU8:
-        case armnn::DataType::QSymmS16:
-            qScale = 1.0f;
-            break;
-        default:
-            break;
-    }
-
     armnn::TensorInfo inputXInfo({1,2,2,1}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo inputYInfo({1,2,2,1}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo outputInfo({1,2,2,1}, ArmnnType, qScale, qOffset);
@@ -432,20 +388,9 @@
 {
     auto descriptor = armnn::BatchMatMulDescriptor(); // Arbitrary layout with no transpose/adjointing
 
-    float qScale = 0.0f;
+    float qScale = 1.0f;
     int32_t qOffset = 0;
 
-    switch(ArmnnType)
-    {
-        case armnn::DataType::QAsymmS8:
-        case armnn::DataType::QAsymmU8:
-        case armnn::DataType::QSymmS16:
-            qScale = 1.0f;
-            break;
-        default:
-            break;
-    }
-
     armnn::TensorInfo inputXInfo({2,2,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo inputYInfo({2,2,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo outputInfo({2,2,2}, ArmnnType, qScale, qOffset);
@@ -530,20 +475,9 @@
 {
     auto descriptor = armnn::BatchMatMulDescriptor(); // Arbitrary layout with no transpose/adjointing
 
-    float qScale = 0.0f;
+    float qScale = 1.0f;
     int32_t qOffset = 0;
 
-    switch(ArmnnType)
-    {
-        case armnn::DataType::QAsymmS8:
-        case armnn::DataType::QAsymmU8:
-        case armnn::DataType::QSymmS16:
-            qScale = 1.0f;
-            break;
-        default:
-            break;
-    }
-
     armnn::TensorInfo inputXInfo({2,2,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo inputYInfo({1,2,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo outputInfo({2,2,2}, ArmnnType, qScale, qOffset);
@@ -625,20 +559,9 @@
 {
     auto descriptor = armnn::BatchMatMulDescriptor(); // Arbitrary layout with no transpose/adjointing
 
-    float qScale = 0.0f;
+    float qScale = 1.0f;
     int32_t qOffset = 0;
 
-    switch(ArmnnType)
-    {
-        case armnn::DataType::QAsymmS8:
-        case armnn::DataType::QAsymmU8:
-        case armnn::DataType::QSymmS16:
-            qScale = 1.0f;
-            break;
-        default:
-            break;
-    }
-
     armnn::TensorInfo inputXInfo({2,2,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo inputYInfo({2,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo outputInfo({2,2,2}, ArmnnType, qScale, qOffset);
@@ -725,20 +648,9 @@
                                                    armnn::DataLayout::NDHWC,
                                                    armnn::DataLayout::NHWC);
 
-    float qScale = 0.0f;
+    float qScale = 1.0f;
     int32_t qOffset = 0;
 
-    switch(ArmnnType)
-    {
-        case armnn::DataType::QAsymmS8:
-        case armnn::DataType::QAsymmU8:
-        case armnn::DataType::QSymmS16:
-            qScale = 1.0f;
-            break;
-        default:
-            break;
-    }
-
     armnn::TensorInfo inputXInfo({1,1,2,2,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo inputYInfo({1,2,2,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo outputInfo({1,1,2,2,2}, ArmnnType, qScale, qOffset);
@@ -823,20 +735,9 @@
 {
     auto descriptor = armnn::BatchMatMulDescriptor(); // Arbitrary layout with no transpose/adjointing
 
-    float qScale = 0.0f;
+    float qScale = 1.0f;
     int32_t qOffset = 0;
 
-    switch(ArmnnType)
-    {
-        case armnn::DataType::QAsymmS8:
-        case armnn::DataType::QAsymmU8:
-        case armnn::DataType::QSymmS16:
-            qScale = 1.0f;
-            break;
-        default:
-            break;
-    }
-
     armnn::TensorInfo inputXInfo({1,1}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo inputYInfo({1,1}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo outputInfo({1,1}, ArmnnType, qScale, qOffset);
@@ -909,20 +810,9 @@
 {
     auto descriptor = armnn::BatchMatMulDescriptor(); // Arbitrary layout with no transpose/adjointing
 
-    float qScale = 0.0f;
+    float qScale = 1.0f;
     int32_t qOffset = 0;
 
-    switch(ArmnnType)
-    {
-        case armnn::DataType::QAsymmS8:
-        case armnn::DataType::QAsymmU8:
-        case armnn::DataType::QSymmS16:
-            qScale = 1.0f;
-            break;
-        default:
-            break;
-    }
-
     armnn::TensorInfo inputXInfo({2,5,3}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo inputYInfo({2,3,4}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo outputInfo({2,5,4}, ArmnnType, qScale, qOffset);
@@ -1024,20 +914,9 @@
                                                    false,
                                                    false);
 
-    float qScale = 0.0f;
+    float qScale = 1.0f;
     int32_t qOffset = 0;
 
-    switch(ArmnnType)
-    {
-        case armnn::DataType::QAsymmS8:
-        case armnn::DataType::QAsymmU8:
-        case armnn::DataType::QSymmS16:
-            qScale = 1.0f;
-            break;
-        default:
-            break;
-    }
-
     armnn::TensorInfo inputXInfo({2,3}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo inputYInfo({2,3}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo outputInfo({3,3}, ArmnnType, qScale, qOffset);
@@ -1117,20 +996,9 @@
                                                    true,
                                                    false);
 
-    float qScale = 0.0f;
+    float qScale = 1.0f;
     int32_t qOffset = 0;
 
-    switch(ArmnnType)
-    {
-        case armnn::DataType::QAsymmS8:
-        case armnn::DataType::QAsymmU8:
-        case armnn::DataType::QSymmS16:
-            qScale = 1.0f;
-            break;
-        default:
-            break;
-    }
-
     armnn::TensorInfo inputXInfo({3,3}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo inputYInfo({3,3}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo outputInfo({3,3}, ArmnnType, qScale, qOffset);
@@ -1227,20 +1095,9 @@
                                                    armnn::DataLayout::NHWC,
                                                    armnn::DataLayout::NHWC);
 
-    float qScale = 0.0f;
+    float qScale = 1.0f;
     int32_t qOffset = 0;
 
-    switch(ArmnnType)
-    {
-        case armnn::DataType::QAsymmS8:
-        case armnn::DataType::QAsymmU8:
-        case armnn::DataType::QSymmS16:
-            qScale = 1.0f;
-            break;
-        default:
-            break;
-    }
-
     armnn::TensorInfo inputXInfo({1,4,4,2}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo inputYInfo({2,2,4,1}, ArmnnType, qScale, qOffset);
     armnn::TensorInfo outputInfo({2,4,2,2}, ArmnnType, qScale, qOffset);
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index cf541f4..7f311d8 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -24,6 +24,7 @@
 #include "workloads/NeonAdditionWorkload.hpp"
 #include "workloads/NeonActivationWorkload.hpp"
 #include "workloads/NeonArgMinMaxWorkload.hpp"
+#include "workloads/NeonBatchMatMulWorkload.hpp"
 #include "workloads/NeonBatchNormalizationWorkload.hpp"
 #include "workloads/NeonBatchToSpaceNdWorkload.hpp"
 #include "workloads/NeonCastWorkload.hpp"
@@ -171,6 +172,12 @@
                                         infos[1],
                                         *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
+        case LayerType::BatchMatMul:
+            return IsBatchMatMulSupported(infos[0],
+                                          infos[1],
+                                          infos[2],
+                                          *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
+                                          reasonIfUnsupported);
         case LayerType::BatchNormalization:
             return IsBatchNormalizationSupported(infos[0],
                                                  infos[1],
@@ -627,6 +634,20 @@
                                    descriptor);
 }
 
+bool NeonLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
+                                              const TensorInfo& inputY,
+                                              const TensorInfo& output,
+                                              const BatchMatMulDescriptor& descriptor,
+                                              Optional<std::string&> reasonIfUnsupported) const
+{
+    FORWARD_WORKLOAD_VALIDATE_FUNC(NeonBatchMatMulValidate,
+                                   reasonIfUnsupported,
+                                   inputX,
+                                   inputY,
+                                   output,
+                                   descriptor);
+}
+
 bool NeonLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const TensorInfo& mean,
diff --git a/src/backends/neon/NeonLayerSupport.hpp b/src/backends/neon/NeonLayerSupport.hpp
index 783e6a0..e916162 100644
--- a/src/backends/neon/NeonLayerSupport.hpp
+++ b/src/backends/neon/NeonLayerSupport.hpp
@@ -41,6 +41,12 @@
                               const ArgMinMaxDescriptor& descriptor,
                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
+    bool IsBatchMatMulSupported(const TensorInfo& inputX,
+                                const TensorInfo& inputY,
+                                const TensorInfo& output,
+                                const BatchMatMulDescriptor& descriptor,
+                                Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
+
     bool IsBatchNormalizationSupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        const TensorInfo& mean,
diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp
index ff9ef26..d5a7c68 100644
--- a/src/backends/neon/NeonWorkloadFactory.cpp
+++ b/src/backends/neon/NeonWorkloadFactory.cpp
@@ -152,6 +152,11 @@
             auto argMinMaxQueueDescriptor = PolymorphicDowncast<const ArgMinMaxQueueDescriptor*>(&descriptor);
             return std::make_unique<NeonArgMinMaxWorkload>(*argMinMaxQueueDescriptor, info);
         }
+        case LayerType::BatchMatMul :
+        {
+            auto batchMatMulQueueDescriptor = PolymorphicDowncast<const BatchMatMulQueueDescriptor*>(&descriptor);
+            return std::make_unique<NeonBatchMatMulWorkload>(*batchMatMulQueueDescriptor, info);
+        }
         case LayerType::BatchNormalization :
         {
             auto batchNormalizationQueueDescriptor
diff --git a/src/backends/neon/backend.mk b/src/backends/neon/backend.mk
index 7c0974c..b1c0103 100644
--- a/src/backends/neon/backend.mk
+++ b/src/backends/neon/backend.mk
@@ -26,6 +26,7 @@
         workloads/NeonActivationWorkload.cpp \
         workloads/NeonAdditionWorkload.cpp \
         workloads/NeonArgMinMaxWorkload.cpp \
+        workloads/NeonBatchMatMulWorkload.cpp \
         workloads/NeonBatchNormalizationWorkload.cpp \
         workloads/NeonBatchToSpaceNdWorkload.cpp \
         workloads/NeonCastWorkload.cpp \
diff --git a/src/backends/neon/test/NeonLayerTests.cpp b/src/backends/neon/test/NeonLayerTests.cpp
index 91fb4d7..88e513e 100644
--- a/src/backends/neon/test/NeonLayerTests.cpp
+++ b/src/backends/neon/test/NeonLayerTests.cpp
@@ -50,6 +50,23 @@
 ARMNN_AUTO_TEST_CASE_WITH_THF(BatchToSpaceNdNchwUint2, BatchToSpaceNdNchwTest2<DataType::QAsymmU8>)
 ARMNN_AUTO_TEST_CASE_WITH_THF(BatchToSpaceNdNchwUint3, BatchToSpaceNdNchwTest3<DataType::QAsymmU8>)
 
+// Batch Mat Mul
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleFloat32, BatchMatMul2DSimpleTest<DataType::Float32>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleFloat32, BatchMatMul3DSimpleTest<DataType::Float32>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleFloat32, BatchMatMulNCHWSimpleTest<DataType::Float32>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchFloat32, BatchMatMul3DBatchTest<DataType::Float32>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastFloat32, BatchMatMul3DBroadcastTest<DataType::Float32>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastFloat32, BatchMatMul3D2DBroadcastTest<DataType::Float32>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyFloat32, BatchMatMul2DTinyTest<DataType::Float32>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleFloat32, BatchMatMul2DTranspSimpleTest<DataType::Float32>);
+
 // Convolution
 ARMNN_AUTO_TEST_CASE_WITH_THF(SimpleConvolution1d, Convolution1dTest, true)
 
diff --git a/src/backends/neon/workloads/CMakeLists.txt b/src/backends/neon/workloads/CMakeLists.txt
index 2209bf4..dd09ecf 100644
--- a/src/backends/neon/workloads/CMakeLists.txt
+++ b/src/backends/neon/workloads/CMakeLists.txt
@@ -12,6 +12,8 @@
     NeonAdditionWorkload.hpp
     NeonArgMinMaxWorkload.cpp
     NeonArgMinMaxWorkload.hpp
+    NeonBatchMatMulWorkload.cpp
+    NeonBatchMatMulWorkload.hpp
     NeonBatchNormalizationWorkload.cpp
     NeonBatchNormalizationWorkload.hpp
     NeonBatchToSpaceNdWorkload.cpp
diff --git a/src/backends/neon/workloads/NeonBatchMatMulWorkload.cpp b/src/backends/neon/workloads/NeonBatchMatMulWorkload.cpp
new file mode 100644
index 0000000..3d8651f
--- /dev/null
+++ b/src/backends/neon/workloads/NeonBatchMatMulWorkload.cpp
@@ -0,0 +1,190 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "NeonBatchMatMulWorkload.hpp"
+
+#include "NeonWorkloadUtils.hpp"
+
+#include <armnn/utility/PolymorphicDowncast.hpp>
+
+#include <armnnUtils/Permute.hpp>
+
+#include <backendsCommon/WorkloadUtils.hpp>
+
+#include <arm_compute/runtime/NEON/functions/NEGEMM.h>
+
+#include <arm_compute/runtime/NEON/functions/NEPermute.h>
+
+
+namespace armnn
+{
+arm_compute::Status NeonBatchMatMulValidate(const TensorInfo& inputX,
+                                            const TensorInfo& inputY,
+                                            const TensorInfo& output,
+                                            const BatchMatMulDescriptor& descriptor)
+{
+    if (descriptor.m_AdjointX || descriptor.m_AdjointY )
+    {
+        throw Exception("Support for adjoint not implemented.");
+    }
+    if (descriptor.m_DataLayoutX != armnn::DataLayout::NCHW || descriptor.m_DataLayoutY != armnn::DataLayout::NCHW )
+    {
+        throw Exception("Only supported the MatMul in the last 2 dimensions");
+    }
+
+    const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX);
+    const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY);
+    const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+
+    arm_compute::Status statusGEMM = arm_compute::Status(arm_compute::ErrorCode::OK);
+    arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK);
+    arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK);
+
+    arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
+    arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
+
+    if (descriptor.m_TransposeX == true)
+    {
+        auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputX.GetNumDimensions());
+        const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
+        const TensorInfo permutedXInfo = armnnUtils::Permuted(inputX, permutationXVector);
+        aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo);
+
+        statusPermuteX = arm_compute::NEPermute::validate(&aclInputXInfo,
+                                                          &aclPermutedXInfo,
+                                                          aclPermutationXVector);
+    }
+
+    if (descriptor.m_TransposeY == true)
+    {
+        auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputY.GetNumDimensions());
+        const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
+        const TensorInfo permutedYInfo = armnnUtils::Permuted(inputY, permutationYVector);
+        aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo);
+
+        statusPermuteY = arm_compute::NEPermute::validate(&aclInputYInfo,
+                                                          &aclPermutedYInfo,
+                                                          aclPermutationYVector);
+    }
+
+    const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false,  // is inputX reshaped
+                                                                   false,  // is inputY reshaped
+                                                                   false); // is inputY reshaped only 1st run
+
+    statusGEMM = arm_compute::NEGEMM::validate(descriptor.m_TransposeX ? &aclPermutedXInfo : &aclInputXInfo,
+                                               descriptor.m_TransposeY ? &aclPermutedYInfo : &aclInputYInfo,
+                                               nullptr,
+                                               &aclOutputInfo,
+                                               1.0,
+                                               0,
+                                               gemm_info);
+
+    if (statusPermuteX.error_code() == arm_compute::ErrorCode::OK &&
+        statusPermuteY.error_code() == arm_compute::ErrorCode::OK &&
+        statusGEMM.error_code()     == arm_compute::ErrorCode::OK)
+    {
+        return arm_compute::Status(arm_compute::ErrorCode::OK,
+                                   "All BatchMatMul layers validate status OK.");
+    }
+    else
+    {
+        return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
+                                   "BatchMatMul layer validate status failed."
+                                   + statusGEMM.error_description()
+                                   + statusPermuteX.error_description()
+                                   + statusPermuteY.error_description());
+    }
+
+}
+
+NeonBatchMatMulWorkload::NeonBatchMatMulWorkload(
+    const BatchMatMulQueueDescriptor& descriptor, const WorkloadInfo& info)
+    : NeonBaseWorkload<BatchMatMulQueueDescriptor>(descriptor, info)
+{
+    if (descriptor.m_Parameters.m_AdjointX || descriptor.m_Parameters.m_AdjointY )
+    {
+        throw Exception("Support for adjoint not implemented.");
+    }
+    if (descriptor.m_Parameters.m_DataLayoutX != armnn::DataLayout::NCHW ||
+        descriptor.m_Parameters.m_DataLayoutY != armnn::DataLayout::NCHW )
+    {
+        throw Exception("Only supported the MatMul in the last 2 dimensions");
+    }
+
+    // Report Profiling Details
+    ARMNN_REPORT_PROFILING_WORKLOAD_DESC("NeonBatchMatMulWorkload_Construct",
+                                         descriptor.m_Parameters,
+                                         info,
+                                         this->GetGuid());
+
+    m_Data.ValidateInputsOutputs("NeonBatchMatMulWorkload", 2, 1);
+
+    arm_compute::ITensor& inputX = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
+    arm_compute::ITensor& inputY = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
+    auto outputHandle = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Outputs[0]);
+    arm_compute::ITensor& output = outputHandle->GetTensor();
+
+    arm_compute::DataLayout aclDataLayoutX = ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX);
+    arm_compute::DataLayout aclDataLayoutY = ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY);
+
+    inputX.info()->set_data_layout(aclDataLayoutX);
+    inputY.info()->set_data_layout(aclDataLayoutY);
+
+    if (descriptor.m_Parameters.m_TransposeX == true)
+    {
+        armnn::PermutationVector permutationXVector
+                = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[0].GetNumDimensions());
+        const TensorInfo permutedXInfo = armnnUtils::Permuted(info.m_InputTensorInfos[0], permutationXVector);
+        const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
+
+        auto permuteLayerX = std::make_unique<arm_compute::NEPermute>();
+        BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
+        InitialiseArmComputeTensorEmpty(m_PermutedTensorX);
+        permuteLayerX->configure(&inputX, &m_PermutedTensorX, aclPermutationXVector);
+        m_PermuteLayerX.reset(permuteLayerX.release());
+    }
+
+    if (descriptor.m_Parameters.m_TransposeY == true)
+    {
+        armnn::PermutationVector permutationYVector
+                = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[1].GetNumDimensions());
+        const TensorInfo permutedYInfo = armnnUtils::Permuted(info.m_InputTensorInfos[1], permutationYVector);
+        const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
+
+        auto permuteLayerY = std::make_unique<arm_compute::NEPermute>();
+        BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
+        InitialiseArmComputeTensorEmpty(m_PermutedTensorY);
+        permuteLayerY->configure(&inputY, &m_PermutedTensorY, aclPermutationYVector);
+        m_PermuteLayerY.reset(permuteLayerY.release());
+    }
+
+    const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false,  // is inputX reshaped
+                                                                   false,  // is inputY reshaped
+                                                                   false); // is inputY reshaped only 1st run
+    auto gemmLayer = std::make_unique<arm_compute::NEGEMM>();
+    gemmLayer->configure(descriptor.m_Parameters.m_TransposeX ? &m_PermutedTensorX : &inputX,
+                         descriptor.m_Parameters.m_TransposeY ? &m_PermutedTensorY : &inputY,
+                         nullptr,
+                         &output,
+                         1.0,
+                         0,
+                         gemm_info);
+    m_GEMMLayer.reset(gemmLayer.release());
+}
+
+void NeonBatchMatMulWorkload::Execute() const
+{
+    ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonBatchMatMulWorkload_Execute", this->GetGuid());
+    if (m_PermuteLayerX)
+    {
+        m_PermuteLayerX->run();
+    }
+    if (m_PermuteLayerY)
+    {
+        m_PermuteLayerY->run();
+    }
+    m_GEMMLayer->run();
+}
+} //namespace armnn
diff --git a/src/backends/neon/workloads/NeonBatchMatMulWorkload.hpp b/src/backends/neon/workloads/NeonBatchMatMulWorkload.hpp
new file mode 100644
index 0000000..cb004d2
--- /dev/null
+++ b/src/backends/neon/workloads/NeonBatchMatMulWorkload.hpp
@@ -0,0 +1,41 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "NeonBaseWorkload.hpp"
+
+#include <arm_compute/runtime/IFunction.h>
+#include <arm_compute/runtime/Tensor.h>
+
+#include <memory>
+
+namespace armnn
+{
+    arm_compute::Status NeonBatchMatMulValidate(const TensorInfo& inputX,
+                                                const TensorInfo& inputY,
+                                                const TensorInfo& output,
+                                                const BatchMatMulDescriptor& descriptor);
+
+    class NeonBatchMatMulWorkload : public NeonBaseWorkload<BatchMatMulQueueDescriptor>
+    {
+    public:
+        NeonBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
+                                const WorkloadInfo& info);
+        virtual void Execute() const override;
+
+    private:
+        // ACL layers required to fully form a Batch Mat Mul layer.
+        std::unique_ptr<arm_compute::IFunction> m_GEMMLayer;
+        std::unique_ptr<arm_compute::IFunction> m_PermuteLayerX;
+        std::unique_ptr<arm_compute::IFunction> m_PermuteLayerY;
+
+        // Additional ACL arm_compute::Tensors.
+        // Required to perform permutations.
+        arm_compute::Tensor m_PermutedTensorX;
+        arm_compute::Tensor m_PermutedTensorY;
+
+    };
+} //namespace armnn
diff --git a/src/backends/neon/workloads/NeonWorkloads.hpp b/src/backends/neon/workloads/NeonWorkloads.hpp
index 8f83674..c9c5421 100644
--- a/src/backends/neon/workloads/NeonWorkloads.hpp
+++ b/src/backends/neon/workloads/NeonWorkloads.hpp
@@ -8,6 +8,7 @@
 #include "NeonActivationWorkload.hpp"
 #include "NeonAdditionWorkload.hpp"
 #include "NeonArgMinMaxWorkload.hpp"
+#include "NeonBatchMatMulWorkload.hpp"
 #include "NeonBatchNormalizationWorkload.hpp"
 #include "NeonBatchToSpaceNdWorkload.hpp"
 #include "NeonCastWorkload.hpp"