IVGCVSW-6449 Add GEMM operator support to ONNX parser

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I3c6979c72d44a15fb2dc3afc22ac30d1428684b0
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b80dcad..8fd7123 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -766,6 +766,7 @@
             src/armnnOnnxParser/test/Flatten.cpp
             src/armnnOnnxParser/test/FullyConnected.cpp
             src/armnnOnnxParser/test/Gather.cpp
+            src/armnnOnnxParser/test/Gemm.cpp
             src/armnnOnnxParser/test/GetInputsOutputs.cpp
             src/armnnOnnxParser/test/OnnxParserTestUtils.cpp
             src/armnnOnnxParser/test/OnnxParserTestUtils.hpp
diff --git a/docs/01_01_parsers.dox b/docs/01_01_parsers.dox
index 2304e15..adc3051 100644
--- a/docs/01_01_parsers.dox
+++ b/docs/01_01_parsers.dox
@@ -88,6 +88,8 @@
   - The parser only supports 2D convolutions with a group = 1 or group = #Nb_of_channel (depthwise convolution)
 - BatchNormalization
   - The parser does not support training mode. See the ONNX [BatchNormalization documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#BatchNormalization) for more information.
+- Gemm
+  - The parser only supports constant bias or non-constant bias where bias dimension = 1. See the ONNX [Gemm documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm) for more information.
 - MatMul
   - The parser only supports constant weights in a fully connected layer.
 
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 6caf690..3588975 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -434,7 +434,8 @@
     { "Shape",                 &OnnxParserImpl::ParseShape },
     { "Gather",                &OnnxParserImpl::ParseGather },
     { "Unsqueeze",             &OnnxParserImpl::ParseUnsqueeze },
-    { "Concat",                &OnnxParserImpl::ParseConcat }
+    { "Concat",                &OnnxParserImpl::ParseConcat },
+    { "Gemm",                  &OnnxParserImpl::ParseGemm }
 };
 
 template<typename TypePair, typename Location>
@@ -1800,6 +1801,175 @@
     RegisterOutputSlots(layer, { node.output(0) });
 }
 
+void OnnxParserImpl::ParseGemm(const onnx::NodeProto& node)
+{
+    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2, 3);
+    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
+
+    int transA = static_cast<int>(ReadOptionalNodeUint32Attribute(node, "transA", 0));
+    int transB = static_cast<int>(ReadOptionalNodeUint32Attribute(node, "transB", 0));
+    float alpha = ReadOptionalNodeFloatAttribute(node, "alpha", 1.0);
+    float beta = ReadOptionalNodeFloatAttribute(node, "beta", 1.0);
+    bool biasEnabled = node.input_size() == 3;
+
+    TensorShape input0Shape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+    TensorShape input1Shape = m_TensorsInfo[node.input(1)].m_info->GetShape();
+
+    // if transB != 0, add transpose to the input1 (tanspose weight matrix in FullyConnected)
+    armnn::FullyConnectedDescriptor fullyConnectedDescriptor;
+    fullyConnectedDescriptor.m_BiasEnabled = biasEnabled;
+    fullyConnectedDescriptor.m_TransposeWeightMatrix = transB;
+
+    IConnectableLayer* layer = nullptr;
+
+    // Just add a FullyConnected layer, weights and biases are handled as inputs now.
+    layer = m_Network->AddFullyConnectedLayer(fullyConnectedDescriptor, node.name().c_str());
+    ARMNN_ASSERT(layer != nullptr);
+
+    // if transA != 0, add transpose to the input0
+    if (transA != 0)
+    {
+        std::string transAName = "transpose_" + node.input(0);
+        armnn::TransposeDescriptor transposeADescriptor;
+        transposeADescriptor.m_DimMappings = { 1, 0 };
+        IConnectableLayer* transALayer = m_Network->AddTransposeLayer(transposeADescriptor, transAName.c_str());
+        ARMNN_ASSERT(transALayer != nullptr);
+        auto transAInfo = ComputeOutputInfo({ transAName }, transALayer, { input0Shape });
+        transALayer->GetOutputSlot(0).SetTensorInfo(transAInfo[0]);
+        transALayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u));
+        // register the input connection slots for the layer, connections are made after all layers have been created
+        RegisterInputSlot(transALayer, node.input(0), 0);
+        input0Shape = transAInfo[0].GetShape();
+    }
+    else
+    {
+        RegisterInputSlot(layer, node.input(0), 0);
+    }
+
+    // Add constant layer to store weights/biases and connect to FullyConnected layer.
+    if(m_TensorsInfo[node.input(1)].isConstant())
+    {
+        IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(CreateConstTensor(node.input(1)).first);
+        TensorInfo weightInfo = *m_TensorsInfo[node.input(1)].m_info;
+        weightInfo.SetConstant();
+        weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);
+
+        // if alpha != 1, multiply to the weight
+        if (alpha != 1)
+        {
+            std::string activationName = "activation_" + node.input(1);
+            armnn::ActivationDescriptor activationDescriptor;
+            activationDescriptor.m_A = alpha;
+            activationDescriptor.m_Function = ActivationFunction::Linear;
+            IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str());
+            ARMNN_ASSERT(actLayer != nullptr);
+
+            auto actInfo = ComputeOutputInfo({ activationName }, actLayer, { weightInfo.GetShape() });
+            actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]);
+            actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
+            weightsLayer->GetOutputSlot(0).Connect(actLayer->GetInputSlot(0u));
+            input1Shape = actInfo[0].GetShape();
+        }
+        else
+        {
+            weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
+            input1Shape = weightInfo.GetShape();
+        }
+    }
+    else
+    {
+        // if alpha != 1, multiply to the weight
+        if (alpha != 1)
+        {
+            std::string activationName = "activation_" + node.input(1);
+            armnn::ActivationDescriptor activationDescriptor;
+            activationDescriptor.m_A = alpha;
+            activationDescriptor.m_Function = ActivationFunction::Linear;
+            IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str());
+            ARMNN_ASSERT(actLayer != nullptr);
+
+            auto actInfo = ComputeOutputInfo({ activationName }, actLayer, { input1Shape });
+            actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]);
+            actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
+            RegisterInputSlot(actLayer, node.input(1), 0);
+            input1Shape = actInfo[0].GetShape();
+        }
+        else
+        {
+            RegisterInputSlot(layer, node.input(1), 1);
+        }
+    }
+
+    if(biasEnabled && m_TensorsInfo[node.input(2)].isConstant())
+    {
+        To1DTensor(node.input(2), CHECK_LOCATION());
+        IConnectableLayer* biasLayer = m_Network->AddConstantLayer(CreateConstTensor(node.input(2)).first);
+        TensorInfo biasInfo = *m_TensorsInfo[node.input(2)].m_info;
+        biasInfo.SetConstant();
+        biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo);
+
+        // if beta != 1, multiply to the bias
+        if (beta != 1)
+        {
+            std::string activationName = "activation_" + node.input(2);
+            armnn::ActivationDescriptor activationDescriptor;
+            activationDescriptor.m_A = beta;
+            activationDescriptor.m_Function = ActivationFunction::Linear;
+            IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str());
+            ARMNN_ASSERT(actLayer != nullptr);
+
+            auto actInfo = ComputeOutputInfo({ activationName }, actLayer, { biasInfo.GetShape() });
+            actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]);
+            actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
+            biasLayer->GetOutputSlot(0).Connect(actLayer->GetInputSlot(0u));
+        }
+        else
+        {
+            biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
+        }
+    }
+    else if (biasEnabled)
+    {
+        // Currently we support non-constant tensor of input C (bias) of Gemm when the dimension is 1
+        if (m_TensorsInfo[node.input(2)].m_info->GetNumDimensions() != 1)
+        {
+            throw ParseException(fmt::format("The parser supports constant or non-constant with 1 dimension for "
+                                             "Input C of Gemm. Input '{}' in '{}' is not supported '{}'",
+                                             node.input(2),
+                                             node.name(),
+                                             CHECK_LOCATION().AsString()));
+        }
+        // if beta != 1, multiply to the bias
+        if (beta != 1)
+        {
+            std::string activationName = "activation_" + node.input(2);
+            armnn::ActivationDescriptor activationDescriptor;
+            activationDescriptor.m_A = beta;
+            activationDescriptor.m_Function = ActivationFunction::Linear;
+            IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str());
+            ARMNN_ASSERT(actLayer != nullptr);
+
+            auto actInfo = ComputeOutputInfo({ activationName },
+                                             actLayer,
+                                             { m_TensorsInfo[node.input(2)].m_info->GetShape() });
+            actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]);
+            actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
+            RegisterInputSlot(actLayer, node.input(2), 0);
+        }
+        else
+        {
+            RegisterInputSlot(layer, node.input(2), 2);
+        }
+    }
+
+    // Set final output of the FullyConnected layer
+    auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
+                                        { input0Shape, input1Shape });
+    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
+
+    RegisterOutputSlots(layer, {node.output(0)});
+}
+
 void OnnxParserImpl::ParseGlobalAveragePool(const onnx::NodeProto& node)
 {
     Pooling2dDescriptor desc = Pooling2dDescriptor();
@@ -2031,6 +2201,22 @@
     }
 }
 
+void OnnxParserImpl::RegisterInputSlot(IConnectableLayer* layer,
+                                       const std::string& tensorId,
+                                       unsigned int slotIndex)
+{
+    armnn::IInputSlot* slot = &(layer->GetInputSlot(slotIndex));
+
+    auto it = m_TensorConnections.find(tensorId);
+
+    if (it == m_TensorConnections.end())
+    {
+        //First time seing this tensor, we need to map it
+        m_TensorConnections[tensorId] = TensorSlots();
+    }
+    m_TensorConnections[tensorId].inputSlots.push_back(slot);
+}
+
 void OnnxParserImpl::RegisterInputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds)
 {
     ARMNN_ASSERT(layer != nullptr);
diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp
index d388f50..ec19006 100644
--- a/src/armnnOnnxParser/OnnxParser.hpp
+++ b/src/armnnOnnxParser/OnnxParser.hpp
@@ -120,12 +120,16 @@
     void ParseConv(const onnx::NodeProto& nodeProto);
     void ParseFlatten(const onnx::NodeProto& node);
     void ParseGather(const onnx::NodeProto& node);
+    void ParseGemm(const onnx::NodeProto& node);
     void ParseGlobalAveragePool(const onnx::NodeProto& node);
     void ParseMaxPool(const onnx::NodeProto& nodeProto);
     void ParseShape(const onnx::NodeProto& node);
     void ParseReshape(const onnx::NodeProto& nodeProto);
     void ParseUnsqueeze(const onnx::NodeProto& nodeProto);
 
+    void RegisterInputSlot(armnn::IConnectableLayer* layer,
+                           const std::string& tensorId,
+                           unsigned int slotIndex);
     void RegisterInputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
     void RegisterOutputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
 
diff --git a/src/armnnOnnxParser/test/Gemm.cpp b/src/armnnOnnxParser/test/Gemm.cpp
new file mode 100644
index 0000000..f68758f
--- /dev/null
+++ b/src/armnnOnnxParser/test/Gemm.cpp
@@ -0,0 +1,556 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "armnnOnnxParser/IOnnxParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+#include "OnnxParserTestUtils.hpp"
+
+TEST_SUITE("OnnxParser_Gemm")
+{
+
+struct GemmFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+    GemmFixture(const std::string& alpha,
+                const std::string& beta,
+                const std::string& transA,
+                const std::string& transB,
+                const std::vector<int>& inputAShape,
+                const std::vector<int>& inputBShape,
+                const std::vector<int>& inputCShape,
+                const std::vector<int>& outputShape)
+    {
+        m_Prototext = R"(
+                    ir_version: 8
+                    producer_name: "onnx-example"
+                    graph {
+                      node {
+                        input: "A"
+                        input: "B"
+                        input: "C"
+                        output: "Output"
+                        op_type: "Gemm"
+                        attribute {
+                          name: "alpha"
+                          f: )" + alpha + R"(
+                          type: FLOAT
+                        }
+                        attribute {
+                          name: "beta"
+                          f: )" + beta + R"(
+                          type: FLOAT
+                        }
+                        attribute {
+                          name: "transA"
+                          i: )" + transA + R"(
+                          type: INT
+                        }
+                        attribute {
+                          name: "transB"
+                          i: )" + transB + R"(
+                          type: INT
+                        }
+                      }
+                      name: "gem-model"
+                      input {
+                        name: "A"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"(
+                            }
+                          }
+                        }
+                      }
+                      input {
+                        name: "B"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"(
+                            }
+                          }
+                        }
+                      }
+                      input {
+                        name: "C"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              )" + armnnUtils::ConstructTensorShapeString(inputCShape) + R"(
+                            }
+                          }
+                        }
+                      }
+                      output {
+                        name: "Output"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
+                            }
+                          }
+                        }
+                      }
+                    })";
+    }
+};
+
+struct GemmAllAttributesFixture : GemmFixture
+{
+    GemmAllAttributesFixture() : GemmFixture("0.25", "0.35", "1", "1", { 4, 3 }, { 5, 4 }, { 5 }, { 3, 5 })
+    {
+        Setup();
+    }
+};
+
+struct GemmSimpleFixture : GemmFixture
+{
+    GemmSimpleFixture() : GemmFixture("1", "1", "0", "0", { 3, 4 }, { 4, 5 }, { 5 }, { 3, 5 })
+    {
+        Setup();
+    }
+};
+
+struct GemmTransAFixture : GemmFixture
+{
+    GemmTransAFixture() : GemmFixture("1", "1", "1", "0", { 4, 3 }, { 4, 5 }, { 5 }, { 3, 5 })
+    {
+        Setup();
+    }
+};
+
+struct GemmTransBFixture : GemmFixture
+{
+    GemmTransBFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 5 }, { 3, 5 })
+    {
+        Setup();
+    }
+};
+
+struct GemmParseExceptionFixture : GemmFixture
+{
+    GemmParseExceptionFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }, { 3, 5 }) {}
+};
+
+TEST_CASE_FIXTURE(GemmAllAttributesFixture, "GemmTest")
+{
+    RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
+                               6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
+                       {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+                               6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
+                               11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
+                               16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
+                       {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
+                      {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f,
+                                    12.535f, 38.57f, 64.605f, 90.64f, 116.675f,
+                                    10.035f, 32.07f,  54.105f, 76.14f, 98.175f }}});
+}
+
+TEST_CASE_FIXTURE(GemmSimpleFixture, "GemmSimpleTest")
+{
+    RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
+                               6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
+                       {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+                               6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
+                               11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
+                               16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
+                       {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
+                      {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f,
+                                    196.1f, 222.2f, 248.3f, 274.4f, 300.5f,
+                                    60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}});
+}
+
+TEST_CASE_FIXTURE(GemmTransAFixture, "GemmTransposeATest")
+{
+    RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
+                               6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
+                       {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+                               6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
+                               11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
+                               16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
+                       {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
+                      {{"Output", { 180.1f, 210.2f, 240.3f, 270.4f, 300.5f,
+                                    146.1f, 172.2f, 198.3f, 224.4f, 250.5f,
+                                    112.1f, 134.2f, 156.3f, 178.4f, 200.5f }}});
+}
+
+TEST_CASE_FIXTURE(GemmTransBFixture, "GemmTransposeBTest")
+{
+    RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
+                               6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
+                       {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+                               6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
+                               11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
+                               16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
+                       {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
+                      {{"Output", { 100.1f, 268.2f, 436.3f, 604.4f, 772.5f,
+                                    60.1f, 164.2f, 268.3f, 372.4f, 476.5f,
+                                    20.1f, 60.2f, 100.3f, 140.4f, 180.5f }}});
+}
+
+TEST_CASE_FIXTURE(GemmParseExceptionFixture, "GemmParseExceptionTest")
+{
+    // ParseException because Input C is non-constant and has 2 dimension (should be 1 dimension)
+    CHECK_THROWS_AS(Setup(), armnn::ParseException);
+}
+
+struct GemmConstantFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+    GemmConstantFixture()
+    {
+        m_Prototext = R"(
+                    ir_version: 8
+                    producer_name: "onnx-example"
+                    graph {
+                      node {
+                        input: "A"
+                        input: "B"
+                        input: "C"
+                        output: "Output"
+                        op_type: "Gemm"
+                        attribute {
+                          name: "alpha"
+                          f: 0.25
+                          type: FLOAT
+                        }
+                        attribute {
+                          name: "beta"
+                          f: 0.35
+                          type: FLOAT
+                        }
+                        attribute {
+                          name: "transA"
+                          i: 1
+                          type: INT
+                        }
+                        attribute {
+                          name: "transB"
+                          i: 1
+                          type: INT
+                        }
+                      }
+                      name: "gem-model"
+                      initializer {
+                        dims: 5
+                        dims: 4
+                        data_type: 1
+                        float_data: 1.0
+                        float_data: 2.0
+                        float_data: 3.0
+                        float_data: 4.0
+                        float_data: 5.0
+                        float_data: 6.0
+                        float_data: 7.0
+                        float_data: 8.0
+                        float_data: 9.0
+                        float_data: 10.0
+                        float_data: 11.0
+                        float_data: 12.0
+                        float_data: 13.0
+                        float_data: 14.0
+                        float_data: 15.0
+                        float_data: 16.0
+                        float_data: 17.0
+                        float_data: 18.0
+                        float_data: 19.0
+                        float_data: 20.0
+                        name: "B"
+                      }
+                      initializer {
+                        dims: 1
+                        dims: 5
+                        data_type: 1
+                        float_data: 0.1
+                        float_data: 0.2
+                        float_data: 0.3
+                        float_data: 0.4
+                        float_data: 0.5
+                        name: "C"
+                      }
+                      input {
+                        name: "A"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              dim {
+                                dim_value: 4
+                              }
+                              dim {
+                                dim_value: 3
+                              }
+                            }
+                          }
+                        }
+                      }
+                      output {
+                        name: "Output"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              dim {
+                                dim_value: 3
+                              }
+                              dim {
+                                dim_value: 5
+                              }
+                            }
+                          }
+                        }
+                      }
+                    })";
+        Setup();
+    }
+};
+
+TEST_CASE_FIXTURE(GemmConstantFixture, "GemmConstantTest")
+{
+    RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
+                               6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}},
+                      {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f,
+                                    12.535f, 38.57f, 64.605f, 90.64f, 116.675f,
+                                    10.035f, 32.07f,  54.105f, 76.14f, 98.175f }}});
+}
+
+struct GemmConstantSimpleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+    GemmConstantSimpleFixture()
+    {
+        m_Prototext = R"(
+                    ir_version: 8
+                    producer_name: "onnx-example"
+                    graph {
+                      node {
+                        input: "A"
+                        input: "B"
+                        input: "C"
+                        output: "Output"
+                        op_type: "Gemm"
+                        attribute {
+                          name: "alpha"
+                          f: 1
+                          type: FLOAT
+                        }
+                        attribute {
+                          name: "beta"
+                          f: 1
+                          type: FLOAT
+                        }
+                        attribute {
+                          name: "transA"
+                          i: 0
+                          type: INT
+                        }
+                        attribute {
+                          name: "transB"
+                          i: 0
+                          type: INT
+                        }
+                      }
+                      name: "gem-model"
+                      initializer {
+                        dims: 4
+                        dims: 5
+                        data_type: 1
+                        float_data: 1.0
+                        float_data: 2.0
+                        float_data: 3.0
+                        float_data: 4.0
+                        float_data: 5.0
+                        float_data: 6.0
+                        float_data: 7.0
+                        float_data: 8.0
+                        float_data: 9.0
+                        float_data: 10.0
+                        float_data: 11.0
+                        float_data: 12.0
+                        float_data: 13.0
+                        float_data: 14.0
+                        float_data: 15.0
+                        float_data: 16.0
+                        float_data: 17.0
+                        float_data: 18.0
+                        float_data: 19.0
+                        float_data: 20.0
+                        name: "B"
+                      }
+                      initializer {
+                        dims: 1
+                        dims: 5
+                        data_type: 1
+                        float_data: 0.1
+                        float_data: 0.2
+                        float_data: 0.3
+                        float_data: 0.4
+                        float_data: 0.5
+                        name: "C"
+                      }
+                      input {
+                        name: "A"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              dim {
+                                dim_value: 3
+                              }
+                              dim {
+                                dim_value: 4
+                              }
+                            }
+                          }
+                        }
+                      }
+                      output {
+                        name: "Output"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              dim {
+                                dim_value: 3
+                              }
+                              dim {
+                                dim_value: 5
+                              }
+                            }
+                          }
+                        }
+                      }
+                    })";
+        Setup();
+    }
+};
+
+TEST_CASE_FIXTURE(GemmConstantSimpleFixture, "GemmConstantSimpleTest")
+{
+    RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
+                               6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}},
+                      {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f,
+                                    196.1f, 222.2f, 248.3f, 274.4f, 300.5f,
+                                    60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}});
+}
+
+struct GemmABFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+    GemmABFixture(const std::string& alpha,
+                  const std::string& beta,
+                  const std::string& transA,
+                  const std::string& transB,
+                  const std::vector<int>& inputAShape,
+                  const std::vector<int>& inputBShape,
+                  const std::vector<int>& outputShape)
+    {
+        m_Prototext = R"(
+                    ir_version: 8
+                    producer_name: "onnx-example"
+                    graph {
+                      node {
+                        input: "A"
+                        input: "B"
+                        output: "Output"
+                        op_type: "Gemm"
+                        attribute {
+                          name: "alpha"
+                          f: )" + alpha + R"(
+                          type: FLOAT
+                        }
+                        attribute {
+                          name: "beta"
+                          f: )" + beta + R"(
+                          type: FLOAT
+                        }
+                        attribute {
+                          name: "transA"
+                          i: )" + transA + R"(
+                          type: INT
+                        }
+                        attribute {
+                          name: "transB"
+                          i: )" + transB + R"(
+                          type: INT
+                        }
+                      }
+                      name: "gem-model"
+                      input {
+                        name: "A"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"(
+                            }
+                          }
+                        }
+                      }
+                      input {
+                        name: "B"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"(
+                            }
+                          }
+                        }
+                      }
+                      output {
+                        name: "Output"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
+                            }
+                          }
+                        }
+                      }
+                    })";
+        Setup();
+    }
+};
+
+struct GemmAlphaTransAFixture : GemmABFixture
+{
+    GemmAlphaTransAFixture() : GemmABFixture("0.25", "0.35", "1", "0", { 4, 3 }, { 4, 5 }, { 3, 5 }) {}
+};
+
+struct GemmAlphaTransBFixture : GemmABFixture
+{
+    GemmAlphaTransBFixture() : GemmABFixture("0.25", "0.35", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }) {}
+};
+
+TEST_CASE_FIXTURE(GemmAlphaTransAFixture, "GemmAlphaTransATest")
+{
+    RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
+                               6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
+                       {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+                               6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
+                               11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
+                               16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}},
+                      {{"Output", { 45.0f, 52.5f, 60.0f, 67.5f, 75.0f,
+                                    36.5f, 43.0f, 49.5f, 56.0f, 62.5f,
+                                    28.0f, 33.5f, 39.0f, 44.5f, 50.0f }}});
+}
+
+TEST_CASE_FIXTURE(GemmAlphaTransBFixture, "GemmAlphaTransBTest")
+{
+    RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
+                               6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
+                       {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+                               6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
+                               11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
+                               16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}},
+                      {{"Output", { 25.0f, 67.0f, 109.0f, 151.0f, 193.0f,
+                                    15.0f, 41.0f, 67.0f, 93.0f, 119.0f,
+                                    5.0f, 15.0f, 25.0f, 35.0f, 45.0f }}});
+}
+
+}