IVGCVSW-6497: BatchMatMul TfLite Parser

  * Added armnnTfLiteParser for BatchMatMul
  * Added unit testing for parser
  * Updated CMakeLists

Signed-off-by: Samuel Yap <samuel.yap@arm.com>
Change-Id: If6842aaf7cf08f688093b714e2ecea6e8cd87161
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4e4818d..14236c7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -660,6 +660,7 @@
              src/armnnTfLiteParser/test/Addition.cpp
              src/armnnTfLiteParser/test/ArgMinMax.cpp
              src/armnnTfLiteParser/test/AvgPool2D.cpp
+             src/armnnTfLiteParser/test/BatchMatMul.cpp
              src/armnnTfLiteParser/test/BatchToSpaceND.cpp
              src/armnnTfLiteParser/test/Cast.cpp
              src/armnnTfLiteParser/test/Comparison.cpp
diff --git a/src/armnn/layers/BatchMatMulLayer.cpp b/src/armnn/layers/BatchMatMulLayer.cpp
index acd089a..0f86b9d 100644
--- a/src/armnn/layers/BatchMatMulLayer.cpp
+++ b/src/armnn/layers/BatchMatMulLayer.cpp
@@ -37,14 +37,14 @@
     TensorShape inputXShape = inputShapes[0];
     TensorShape inputYShape = inputShapes[1];
 
-    // Adjoint will not affect the resultant shape, as you would be permuting two axes of equal size
-    if(m_Param.m_TransposeX)
+    // Adjoint inputs are assumed to be square matrices, but the permute is applied regardless
+    if(m_Param.m_TransposeX || m_Param.m_AdjointX)
     {
         auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutX,
                                                                inputXShape);
         inputXShape = armnnUtils::Permuted(inputXShape, permuteVec);
     }
-    if(m_Param.m_TransposeY)
+    if(m_Param.m_TransposeY || m_Param.m_AdjointY)
     {
         auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutY,
                                                                inputYShape);
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 880de10..0304203 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -680,6 +680,7 @@
     m_ParserFunctions[tflite::BuiltinOperator_ARG_MAX]                 = &TfLiteParserImpl::ParseArgMax;
     m_ParserFunctions[tflite::BuiltinOperator_AVERAGE_POOL_2D]         = &TfLiteParserImpl::ParseAveragePool2D;
+    m_ParserFunctions[tflite::BuiltinOperator_BATCH_MATMUL]            = &TfLiteParserImpl::ParseBatchMatMul;
     m_ParserFunctions[tflite::BuiltinOperator_BATCH_TO_SPACE_ND]       = &TfLiteParserImpl::ParseBatchToSpaceND;
     m_ParserFunctions[tflite::BuiltinOperator_CAST]                    = &TfLiteParserImpl::ParseCast;
     m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION]           = &TfLiteParserImpl::ParseConcatenation;
     m_ParserFunctions[tflite::BuiltinOperator_CONV_2D]                 = &TfLiteParserImpl::ParseConv2D;
@@ -1565,6 +1566,44 @@
     ParsePool(subgraphIndex, operatorIndex, PoolingAlgorithm::Average);
 }
 
+void TfLiteParserImpl::ParseBatchMatMul(size_t subgraphIndex, size_t operatorIndex)
+{
+    CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+    auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+    CHECK_VALID_SIZE(inputs.size(), 2);
+
+    auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+    CHECK_VALID_SIZE(outputs.size(), 1);
+
+    auto layerName = fmt::format("BatchMatMul:{}:{}", subgraphIndex, operatorIndex);
+
+    TensorInfo inputXTensorInfo = ToTensorInfo(inputs[0]);
+    TensorInfo inputYTensorInfo = ToTensorInfo(inputs[1]);
+
+    TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true);
+
+    const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
+    const auto* options = operatorPtr->builtin_options.AsBatchMatMulOptions();
+
+    BatchMatMulDescriptor descriptor(false,
+                                     false,
+                                     options->adj_x,
+                                     options->adj_y);
+                                     // DataLayoutX/Y are left at their defaults; the layout choice is arbitrary here
+
+    IConnectableLayer* layer = m_Network->AddBatchMatMulLayer(descriptor, layerName.c_str());
+    ARMNN_ASSERT(layer != nullptr);
+
+    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+    auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+    RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
+
+    auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+    RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
+}
+
 void TfLiteParserImpl::ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex)
 {
     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 49744a0..f8ddc55 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -114,6 +114,7 @@
     void ParseArgMin(size_t subgraphIndex, size_t operatorIndex);
     void ParseArgMax(size_t subgraphIndex, size_t operatorIndex);
     void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
+    void ParseBatchMatMul(size_t subgraphIndex, size_t operatorIndex);
     void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex);
     void ParseCast(size_t subgraphIndex, size_t operatorIndex);
     void ParseComparison(size_t subgraphIndex, size_t operatorIndex, armnn::ComparisonOperation comparisonOperation);
diff --git a/src/armnnTfLiteParser/test/BatchMatMul.cpp b/src/armnnTfLiteParser/test/BatchMatMul.cpp
new file mode 100644
index 0000000..f4cdd67
--- /dev/null
+++ b/src/armnnTfLiteParser/test/BatchMatMul.cpp
@@ -0,0 +1,114 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ParserFlatbuffersFixture.hpp"
+
+TEST_SUITE("TensorflowLiteParser_BatchMatMul")
+{
+struct BatchMatMulFixture : public ParserFlatbuffersFixture
+{
+    explicit BatchMatMulFixture(const std::string &inputXShape,
+                                const std::string &inputYShape,
+                                const std::string &outputShape,
+                                const std::string &adjX,
+                                const std::string &adjY)
+    {
+        m_JsonString = R"(
+            {
+                "version": 3,
+                "operator_codes": [ { "builtin_code": "BATCH_MATMUL" } ],
+                "subgraphs": [
+                    {
+                        "tensors": [
+                            {
+                                "shape": )" + inputXShape + R"(,
+                                "type": "FLOAT32",
+                                "buffer": 0,
+                                "name": "inputXTensor",
+                                "quantization": {
+                                    "min": [ 0.0 ],
+                                    "max": [ 255.0 ],
+                                    "scale": [ 1.0 ],
+                                    "zero_point": [ 0 ],
+                                }
+                            },
+                            {
+                                "shape": )" + inputYShape + R"(,
+                                "type": "FLOAT32",
+                                "buffer": 1,
+                                "name": "inputYTensor",
+                                "quantization": {
+                                    "min": [ 0.0 ],
+                                    "max": [ 255.0 ],
+                                    "scale": [ 1.0 ],
+                                    "zero_point": [ 0 ],
+                                }
+                            },
+                            {
+                                "shape": )" + outputShape + R"(,
+                                "type": "FLOAT32",
+                                "buffer": 2,
+                                "name": "outputTensor",
+                                "quantization": {
+                                    "min": [ 0.0 ],
+                                    "max": [ 255.0 ],
+                                    "scale": [ 1.0 ],
+                                    "zero_point": [ 0 ],
+                                }
+                            }
+                        ],
+                        "inputs": [ 0, 1 ],
+                        "outputs": [ 2 ],
+                        "operators": [
+                            {
+                                "opcode_index": 0,
+                                "inputs": [ 0 , 1 ],
+                                "outputs": [ 2 ],
+                                "builtin_options_type": "BatchMatMulOptions",
+                                "builtin_options": {
+                                    "adj_x": )" + adjX + R"(,
+                                    "adj_y": )" + adjY + R"(,
+                                    "asymmetric_quantize_inputs": false
+                                },
+                                "custom_options_format": "FLEXBUFFERS"
+                            }
+                        ]
+                    }
+                ],
+                "buffers": [{},{},{}]
+            }
+        )";
+        Setup();
+    }
+};
+
+struct BatchMatMulParamsFixture : BatchMatMulFixture
+{
+    BatchMatMulParamsFixture()
+        : BatchMatMulFixture("[ 1, 3, 3 ]",
+                             "[ 1, 3, 3 ]",
+                             "[ 1, 3, 3 ]",
+                             "false",
+                             "true")
+    {}
+};
+
+TEST_CASE_FIXTURE(BatchMatMulParamsFixture, "ParseBatchMatMulParams")
+{
+    RunTest<3, armnn::DataType::Float32>(
+        0,
+        {{"inputXTensor", {2.0f, 3.0f, 5.0f,
+                           8.0f, 13.0f, 21.0f,
+                           34.0f, 55.0f, 89.0f}},
+         {"inputYTensor", {0.0f, 1.0f, 1.0f,
+                           1.0f, 0.0f, 1.0f,
+                           1.0f, 1.0f, 0.0f}}},
+        {{"outputTensor", {6.0f, 4.0f, 0.0f,
+                           26.0f, 16.0f, 0.0f,
+                           110.0f, 68.0f, 0.0f}}}
+        );
+}
+
+}
\ No newline at end of file