IVGCVSW-6075 Add ParseExpandDims to TfLiteParser

* Add ExpandDims tests in tfliteparser
* Add support for negative axis to squeeze

Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I604c9b4ac6514895e9e3d4d85c2937e797d288e0
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1ccc80d..b2f32ef 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -671,6 +671,7 @@
              src/armnnTfLiteParser/test/DetectionPostProcess.cpp
              src/armnnTfLiteParser/test/Div.cpp
              src/armnnTfLiteParser/test/ElementWiseUnary.cpp
+             src/armnnTfLiteParser/test/ExpandDims.cpp
              src/armnnTfLiteParser/test/FullyConnected.cpp
              src/armnnTfLiteParser/test/Gather.cpp
              src/armnnTfLiteParser/test/L2Normalization.cpp
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index f38f45f..2df47eb 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -623,6 +623,7 @@
     m_ParserFunctions[tflite::BuiltinOperator_DIV]                     = &TfLiteParserImpl::ParseDiv;
     m_ParserFunctions[tflite::BuiltinOperator_ELU]                     = &TfLiteParserImpl::ParseElu;
     m_ParserFunctions[tflite::BuiltinOperator_EXP]                     = &TfLiteParserImpl::ParseExp;
+    m_ParserFunctions[tflite::BuiltinOperator_EXPAND_DIMS]             = &TfLiteParserImpl::ParseExpandDims;
     m_ParserFunctions[tflite::BuiltinOperator_FULLY_CONNECTED]         = &TfLiteParserImpl::ParseFullyConnected;
     m_ParserFunctions[tflite::BuiltinOperator_GATHER]                  = &TfLiteParserImpl::ParseGather;
     m_ParserFunctions[tflite::BuiltinOperator_HARD_SWISH]              = &TfLiteParserImpl::ParseHardSwish;
@@ -1091,6 +1092,37 @@
     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
 }
 
+void TfLiteParserImpl::ParseExpandDims(size_t subgraphIndex, size_t operatorIndex)
+{
+    CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+    auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+    CHECK_VALID_SIZE(inputs.size(), 2);
+
+    auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+    CHECK_VALID_SIZE(outputs.size(), 1);
+
+    auto layerName = fmt::format("ExpandDims:{}:{}", subgraphIndex, operatorIndex);
+
+    armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+    armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true);
+
+    CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
+
+    ReshapeDescriptor reshapeDesc;
+    reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
+
+    IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
+    ARMNN_ASSERT(layer != nullptr);
+    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+    auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+    RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+
+    auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+    RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
+}
+
 void TfLiteParserImpl::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
 {
     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
@@ -1586,11 +1618,10 @@
     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
 }
 
-armnn::TensorInfo TfLiteParserImpl::OutputShapeOfSqueeze(const std::vector<uint32_t> & squeezeDimsIn,
+armnn::TensorInfo TfLiteParserImpl::OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims,
                                                          const armnn::TensorInfo & inputTensorInfo)
 {
-    CHECK_VALID_SIZE(squeezeDimsIn.size(), 0, 1, 2, 3, 4);
-    std::vector<uint32_t> squeezeDims = squeezeDimsIn;
+    CHECK_VALID_SIZE(squeezeDims.size(), 0, 1, 2, 3, 4);
     static const uint32_t dimensionSequence[] = { 0, 1, 2, 3 };
 
     if (inputTensorInfo.GetNumDimensions() > 4)
@@ -1688,9 +1719,22 @@
     auto layerName = fmt::format("Squeeze:{}:{}", subgraphIndex, operatorIndex);
 
     armnn::TensorInfo inputTensorInfo  = ToTensorInfo(inputs[0]);
-    armnn::TensorInfo outputTensorInfo =
-        TfLiteParserImpl::OutputShapeOfSqueeze(AsUnsignedVector(options->squeeze_dims),
-                                           inputTensorInfo);
+
+    std::vector<uint32_t> squeezeDim;
+    // A single negative squeeze dim is interpreted Python-style: it counts
+    // back from the end, i.e. the effective index is rank + (negative index).
+    if (options->squeeze_dims.size() == 1 && options->squeeze_dims[0] < 0)
+    {
+        int32_t dim = static_cast<int32_t>(inputTensorInfo.GetShape().GetNumDimensions()) + options->squeeze_dims[0];
+        squeezeDim.push_back(static_cast<uint32_t>(dim));
+    }
+    else
+    {
+        squeezeDim = AsUnsignedVector(options->squeeze_dims);
+    }
+
+    armnn::TensorInfo outputTensorInfo = TfLiteParserImpl::OutputShapeOfSqueeze(squeezeDim, inputTensorInfo);
+
     CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
 
     ReshapeDescriptor reshapeDesc;
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 836c4e8..49ccd27 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -64,20 +64,20 @@
 
 public:
     // testable helpers
-    static ModelPtr LoadModelFromFile(const char * fileName);
-    static ModelPtr LoadModelFromBinary(const uint8_t * binaryContent, size_t len);
-    static TensorRawPtrVector GetInputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex);
-    static TensorRawPtrVector GetOutputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex);
-    static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr & model, size_t subgraphIndex);
-    static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr & model, size_t subgraphIndex);
+    static ModelPtr LoadModelFromFile(const char* fileName);
+    static ModelPtr LoadModelFromBinary(const uint8_t* binaryContent, size_t len);
+    static TensorRawPtrVector GetInputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
+    static TensorRawPtrVector GetOutputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
+    static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr& model, size_t subgraphIndex);
+    static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr& model, size_t subgraphIndex);
     static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
     static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
 
     static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex);
-    static armnn::TensorInfo OutputShapeOfSqueeze(const std::vector<uint32_t> & squeezeDims,
-                                                  const armnn::TensorInfo & inputTensorInfo);
-    static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
-                                                  const std::vector<int32_t> & targetDimsIn);
+    static armnn::TensorInfo OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims,
+                                                  const armnn::TensorInfo& inputTensorInfo);
+    static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo,
+                                                  const std::vector<int32_t>& targetDimsIn);
 
     /// Retrieve version in X.Y.Z form
     static const std::string GetVersion();
@@ -116,6 +116,7 @@
     void ParseElementwiseUnary(size_t subgraphIndex, size_t operatorIndex, armnn::UnaryOperation unaryOperation);
     void ParseElu(size_t subgraphIndex, size_t operatorIndex);
     void ParseExp(size_t subgraphIndex, size_t operatorIndex);
+    void ParseExpandDims(size_t subgraphIndex, size_t operatorIndex);
     void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex);
     void ParseGather(size_t subgraphIndex, size_t operatorIndex);
     void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex);
diff --git a/src/armnnTfLiteParser/test/ExpandDims.cpp b/src/armnnTfLiteParser/test/ExpandDims.cpp
new file mode 100644
index 0000000..a9f021f
--- /dev/null
+++ b/src/armnnTfLiteParser/test/ExpandDims.cpp
@@ -0,0 +1,106 @@
+//
+// Copyright © 2021 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "ParserFlatbuffersFixture.hpp"
+#include "../TfLiteParser.hpp"
+
+#include <string>
+#include <iostream>
+
+TEST_SUITE("TensorflowLiteParser_ExpandDims")
+{
+struct ExpandDimsFixture : public ParserFlatbuffersFixture
+{
+    explicit ExpandDimsFixture(const std::string& inputShape,
+                               const std::string& outputShape,
+                               const std::string& axis)
+    {
+        m_JsonString = R"(
+            {
+                "version": 3,
+                "operator_codes": [ { "builtin_code": "EXPAND_DIMS" } ],
+                "subgraphs": [ {
+                    "tensors": [
+                        {
+                            "shape": )" + inputShape + R"(,
+                            "type": "UINT8",
+                            "buffer": 0,
+                            "name": "inputTensor",
+                            "quantization": {
+                                "min": [ 0.0 ],
+                                "max": [ 255.0 ],
+                                "scale": [ 1.0 ],
+                                "zero_point": [ 0 ],
+                            }
+                        },
+                        {
+                            "shape": )" + outputShape + R"( ,
+                            "type": "UINT8",
+                            "buffer": 1,
+                            "name": "outputTensor",
+                            "quantization": {
+                                "min": [ 0.0 ],
+                                "max": [ 255.0 ],
+                                "scale": [ 1.0 ],
+                                "zero_point": [ 0 ],
+                            }
+                        },
+                        {
+                            "shape": [ 1 ],
+                            "type": "UINT8",
+                            "buffer": 2,
+                            "name": "expand_dims",
+                            "quantization": {
+                                "min": [ 0.0 ],
+                                "max": [ 255.0 ],
+                                "scale": [ 1.0 ],
+                                "zero_point": [ 0 ],
+                            }
+                        },
+                    ],
+                    "inputs": [ 0 ],
+                    "outputs": [ 1 ],
+                    "operators": [
+                        {
+                            "opcode_index": 0,
+                            "inputs": [ 0 , 2 ],
+                            "outputs": [ 1 ],
+                            "custom_options_format": "FLEXBUFFERS"
+                        }
+                    ],
+                } ],
+                "buffers" : [
+                    { },
+                    { },
+                    { "data": )" + axis + R"(, },
+                ]
+            }
+        )";
+        SetupSingleInputSingleOutput("inputTensor", "outputTensor");
+    }
+};
+
+struct ExpandDimsFixture3dto4Daxis0 : ExpandDimsFixture
+{
+    ExpandDimsFixture3dto4Daxis0() : ExpandDimsFixture("[ 2, 2, 1 ]", "[ 1, 2, 2, 1 ]", "[ 0, 0, 0, 0 ]") {}
+};
+
+TEST_CASE_FIXTURE(ExpandDimsFixture3dto4Daxis0, "ParseExpandDims3Dto4Daxis0")
+{
+    RunTest<4, armnn::DataType::QAsymmU8>(0, {{ "inputTensor",  { 1, 2, 3, 4 } } },
+                                             {{ "outputTensor", { 1, 2, 3, 4 } } });
+}
+
+struct ExpandDimsFixture3dto4Daxis3 : ExpandDimsFixture
+{
+    ExpandDimsFixture3dto4Daxis3() : ExpandDimsFixture("[ 1, 2, 2 ]", "[ 1, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]") {}
+};
+
+TEST_CASE_FIXTURE(ExpandDimsFixture3dto4Daxis3, "ParseExpandDims3Dto4Daxis3")
+{
+    RunTest<4, armnn::DataType::QAsymmU8>(0, {{ "inputTensor",  { 1, 2, 3, 4 } } },
+                                             {{ "outputTensor", { 1, 2, 3, 4 } } });
+}
+
+}
diff --git a/src/armnnTfLiteParser/test/Squeeze.cpp b/src/armnnTfLiteParser/test/Squeeze.cpp
index da870fd..6f533ba 100644
--- a/src/armnnTfLiteParser/test/Squeeze.cpp
+++ b/src/armnnTfLiteParser/test/Squeeze.cpp
@@ -128,14 +128,44 @@
 }
 
 
-struct SqueezeFixtureWithNegativeSqueezeDims : SqueezeFixture
+struct SqueezeFixtureWithNegativeSqueezeDims1 : SqueezeFixture
 {
-    SqueezeFixtureWithNegativeSqueezeDims() : SqueezeFixture("[ 1, 2, 2, 1 ]",
-                                                             "[ 1, 2, 2, 1 ]",
-                                                             "[ -2 , 2 ]") {}
+    SqueezeFixtureWithNegativeSqueezeDims1() : SqueezeFixture("[ 1, 2, 2, 1 ]",
+                                                              "[ 2, 2, 1 ]",
+                                                              "[ -4 ]") {}
 };
 
-TEST_CASE_FIXTURE(SqueezeFixtureWithNegativeSqueezeDims, "ParseSqueezeNegativeSqueezeDims")
+TEST_CASE_FIXTURE(SqueezeFixtureWithNegativeSqueezeDims1, "ParseSqueezeNegativeSqueezeDims1")
+{
+    SetupSingleInputSingleOutput("inputTensor", "outputTensor");
+    RunTest<3, armnn::DataType::QAsymmU8>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
+    CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
+           == armnn::TensorShape({ 2, 2, 1 })));
+}
+
+struct SqueezeFixtureWithNegativeSqueezeDims2 : SqueezeFixture
+{
+    SqueezeFixtureWithNegativeSqueezeDims2() : SqueezeFixture("[ 1, 2, 2, 1 ]",
+                                                              "[ 1, 2, 2 ]",
+                                                              "[ -1 ]") {}
+};
+
+TEST_CASE_FIXTURE(SqueezeFixtureWithNegativeSqueezeDims2, "ParseSqueezeNegativeSqueezeDims2")
+{
+    SetupSingleInputSingleOutput("inputTensor", "outputTensor");
+    RunTest<3, armnn::DataType::QAsymmU8>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
+    CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
+           == armnn::TensorShape({ 1, 2, 2 })));
+}
+
+struct SqueezeFixtureWithNegativeSqueezeDimsInvalid : SqueezeFixture
+{
+    SqueezeFixtureWithNegativeSqueezeDimsInvalid() : SqueezeFixture("[ 1, 2, 2, 1 ]",
+                                                                    "[ 1, 2, 2, 1 ]",
+                                                                    "[ -2 , 2 ]") {}
+};
+
+TEST_CASE_FIXTURE(SqueezeFixtureWithNegativeSqueezeDimsInvalid, "ParseSqueezeNegativeSqueezeDimsInvalid")
 {
     CHECK_THROWS_AS((SetupSingleInputSingleOutput("inputTensor", "outputTensor")), armnn::ParseException);
 }