IVGCVSW-2193 ExpandDims operation implementation

* Add ExpandDims operation to the TensorFlow parser (TfParser.cpp/.hpp), lowered to a Reshape layer
* Add ExpandDims parser unit tests

Change-Id: Ifa756ae0667c11e3b6daec8f6dd4e54cac88d16a
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 4ddcdce..53cdfa3 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -159,6 +159,17 @@
     return attribValue;
+// Reads the mandatory integer attribute 'name' of 'nodeDef' and returns it as an int32_t.
+// AttrValue::i() is a 64-bit integer; boost::numeric_cast rejects values that do not fit in
+// 32 bits (throwing boost::numeric::bad_numeric_cast) instead of silently truncating them.
+int32_t ReadMandatoryNodeInt32Attribute(const tensorflow::NodeDef& nodeDef, const std::string& name)
+    int32_t attribValue = 0;
+    ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kI,
+                                   [&attribValue](const tensorflow::AttrValue& attrValue)
+                                   {
+                                       attribValue = boost::numeric_cast<int32_t>(attrValue.i());
+                                   });
+    return attribValue;
 uint32_t ReadMandatoryNodeUint32Attribute(const tensorflow::NodeDef& nodeDef, const std::string& name)
     uint32_t attribValue = 0u;
@@ -349,6 +360,7 @@
     { "Identity",              &TfParser::ParseIdentity },
     { "Conv2D",                &TfParser::ParseConv2D },
     { "DepthwiseConv2dNative", &TfParser::ParseDepthwiseConv2D },
+    { "ExpandDims",            &TfParser::ParseExpandDims },
     { "FusedBatchNorm",        &TfParser::ParseFusedBatchNorm },
     { "ConcatV2",              &TfParser::ParseConcat },
     { "LRN",                   &TfParser::ParseLrn },
@@ -1224,6 +1236,100 @@
     return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+// Computes the TensorInfo produced by an ExpandDims node: the input's TensorInfo with a dimension of
+// size 1 inserted into its shape at the position given by the node's "Tdim" attribute.
+// A non-negative Tdim inserts before input dimension Tdim (Tdim == rank appends at the end);
+// a negative Tdim counts from the end, so -1 appends and -1-rank prepends.
+// Throws ParseException if the input or the resulting output has more than 4 dimensions and
+// InvalidArgumentException if Tdim lies outside [-1-rank, rank].
+TensorInfo OutputShapeOfExpandDims(const tensorflow::NodeDef& nodeDef, const TensorInfo& inputTensorInfo)
+    BOOST_ASSERT(nodeDef.op() == "ExpandDims");
+    if (inputTensorInfo.GetNumDimensions() > 4)
+    {
+        throw ParseException(
+                boost::str(
+                        boost::format(
+                                "Unsupported number of dimensions: %1% for input shape for ExpandDims %2% %3%")
+                        % inputTensorInfo.GetNumDimensions()
+                        % nodeDef.name()
+                        % CHECK_LOCATION().AsString()));
+    }
+    std::int32_t expandDim = ReadMandatoryNodeInt32Attribute(nodeDef, "Tdim");
+    std::int32_t inputDimSize = boost::numeric_cast<int32_t>(inputTensorInfo.GetNumDimensions());
+    // The ExpandDims operation requires: -1 - input.dims() <= dim <= input.dims()
+    if (expandDim < -1 - inputDimSize || expandDim > inputDimSize)
+    {
+        throw InvalidArgumentException(
+                boost::str(
+                        boost::format(
+                                "Cannot expand dimension %1% in input tensor with %2% dimension %3%")
+                        % expandDim
+                        % inputDimSize
+                        % CHECK_LOCATION().AsString()));
+    }
+    // Start from a copy of the input shape; room is reserved for the inserted dimension.
+    std::vector<uint32_t> outputDims;
+    outputDims.reserve(inputTensorInfo.GetNumDimensions() + 1);
+    for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i)
+    {
+        outputDims.push_back(inputTensorInfo.GetShape()[i]);
+    }
+    // Map a negative 'expandDim' onto the equivalent non-negative index in the (rank + 1)-dimensional
+    // output (e.g. -1 -> rank, i.e. append; -1-rank -> 0, i.e. prepend). Resolving the index before
+    // touching iterators keeps the insert position within [begin(), end()], as vector::insert requires.
+    const std::int32_t insertIndex = (expandDim < 0) ? expandDim + inputDimSize + 1 : expandDim;
+    outputDims.insert(outputDims.begin() + insertIndex, 1);
+    if (outputDims.size() > 4)
+    {
+        throw ParseException(
+                boost::str(
+                        boost::format(
+                                "Unsupported number of dimensions: %1% for output shape for ExpandDims %2% %3%")
+                        % outputDims.size()
+                        % nodeDef.name()
+                        % CHECK_LOCATION().AsString()));
+    }
+    TensorShape outShape = TensorShape(static_cast<unsigned int>(outputDims.size()),
+                                       outputDims.data());
+    // Everything except the shape (data type, quantization) is carried over from the input.
+    TensorInfo outTensorInfo = inputTensorInfo;
+    outTensorInfo.SetShape(outShape);
+    return outTensorInfo;
+// Parses an ExpandDims node by adding a Reshape layer whose target shape is the input shape with
+// the extra size-1 dimension inserted (as computed by OutputShapeOfExpandDims).
+// graphDef is not used here.
+// NOTE(review): in TensorFlow graphs the ExpandDims axis is normally supplied as a second input and
+// "Tdim" is the dtype of that input; this implementation expects exactly one input and an integer
+// "Tdim" attribute — confirm this matches the graphs the parser has to support.
+ParsedTfOperationPtr TfParser::ParseExpandDims(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+    std::vector<OutputOfParsedTfOperation> parsedInputs = GetInputParsedTfOperationsChecked(nodeDef, 1);
+    IOutputSlot& sourceSlot = parsedInputs[0].m_IndexedValue->ResolveArmnnOutputSlot(parsedInputs[0].m_Index);
+    const TensorInfo expandedInfo = OutputShapeOfExpandDims(nodeDef, sourceSlot.GetTensorInfo());
+    ReshapeDescriptor descriptor;
+    descriptor.m_TargetShape = expandedInfo.GetShape();
+    IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(descriptor, nodeDef.name().c_str());
+    sourceSlot.Connect(reshapeLayer->GetInputSlot(0));
+    reshapeLayer->GetOutputSlot(0).SetTensorInfo(expandedInfo);
+    return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, reshapeLayer);
 ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef,
                                                    const tensorflow::GraphDef& graphDef)
diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp
index 1c29ce2..da78f48 100644
--- a/src/armnnTfParser/TfParser.hpp
+++ b/src/armnnTfParser/TfParser.hpp
@@ -131,6 +131,7 @@
     ParsedTfOperationPtr ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
     ParsedTfOperationPtr ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
     ParsedTfOperationPtr ParseDepthwiseConv2D(const tensorflow::NodeDef& nodeDef,const tensorflow::GraphDef& graphDef);
+    ParsedTfOperationPtr ParseExpandDims(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
     ParsedTfOperationPtr ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
     ParsedTfOperationPtr ParseConcat(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
     ParsedTfOperationPtr ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
diff --git a/src/armnnTfParser/test/ExpandDims.cpp b/src/armnnTfParser/test/ExpandDims.cpp
new file mode 100644
index 0000000..57d472d
--- /dev/null
+++ b/src/armnnTfParser/test/ExpandDims.cpp
@@ -0,0 +1,112 @@
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+// Builds a two-node graph, Placeholder "graphInput" -> ExpandDims "ExpandDims", whose ExpandDims node
+// carries 'expandDim' as the integer value of its "Tdim" attribute. The graph is then parsed with an
+// input shape of { 2, 3, 5 } and "ExpandDims" as the single output.
+struct ExpandDimsFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+    explicit ExpandDimsFixture(const std::string& expandDim)
+    {
+        m_Prototext =
+                "node { \n"
+                "    name: \"graphInput\" \n"
+                "    op: \"Placeholder\" \n"
+                "    attr { \n"
+                "      key: \"dtype\" \n"
+                "      value { \n"
+                "        type: DT_FLOAT \n"
+                "      } \n"
+                "    } \n"
+                "    attr { \n"
+                "      key: \"shape\" \n"
+                "      value { \n"
+                "        shape { \n"
+                "        } \n"
+                "      } \n"
+                "    } \n"
+                "  } \n"
+                "node { \n"
+                "  name: \"ExpandDims\" \n"
+                "  op: \"ExpandDims\" \n"
+                "  input: \"graphInput\" \n"
+                "  attr { \n"
+                "    key: \"T\" \n"
+                "    value { \n"
+                "      type: DT_FLOAT \n"
+                "    } \n"
+                "  } \n"
+                "  attr { \n"
+                "    key: \"Tdim\" \n"
+                "    value { \n";
+        // The axis value, newline-terminated like every other prototext line.
+        m_Prototext += "      i: " + expandDim + " \n";
+        m_Prototext +=
+                "    } \n"
+                "  } \n"
+                "} \n";
+        SetupSingleInputSingleOutput({ 2, 3, 5 }, "graphInput", "ExpandDims");
+    }
+// One fixture per tested "Tdim" value; each simply forwards that value to ExpandDimsFixture.
+struct ExpandZeroDim : public ExpandDimsFixture
+    ExpandZeroDim() : ExpandDimsFixture{"0"} {}
+struct ExpandTwoDim : public ExpandDimsFixture
+    ExpandTwoDim() : ExpandDimsFixture{"2"} {}
+struct ExpandThreeDim : public ExpandDimsFixture
+    ExpandThreeDim() : ExpandDimsFixture{"3"} {}
+struct ExpandMinusOneDim : public ExpandDimsFixture
+    ExpandMinusOneDim() : ExpandDimsFixture{"-1"} {}
+struct ExpandMinusThreeDim : public ExpandDimsFixture
+    ExpandMinusThreeDim() : ExpandDimsFixture{"-3"} {}
+// Output shapes for an input of shape { 2, 3, 5 } expanded at various positions.
+BOOST_FIXTURE_TEST_CASE(ParseExpandZeroDim, ExpandZeroDim)
+    BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+                armnn::TensorShape({1, 2, 3, 5})));
+BOOST_FIXTURE_TEST_CASE(ParseExpandTwoDim, ExpandTwoDim)
+    BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+                armnn::TensorShape({2, 3, 1, 5})));
+BOOST_FIXTURE_TEST_CASE(ParseExpandThreeDim, ExpandThreeDim)
+    BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+                armnn::TensorShape({2, 3, 5, 1})));
+BOOST_FIXTURE_TEST_CASE(ParseExpandMinusOneDim, ExpandMinusOneDim)
+    BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+                armnn::TensorShape({2, 3, 5, 1})));
+BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim)
+    BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+                armnn::TensorShape({2, 1, 3, 5})));
+// -1 - rank (= -4 here) is the most negative valid axis and prepends the new dimension.
+struct ExpandMinusFourDim : ExpandDimsFixture
+    ExpandMinusFourDim() : ExpandDimsFixture("-4") {}
+BOOST_FIXTURE_TEST_CASE(ParseExpandMinusFourDim, ExpandMinusFourDim)
+    BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+                armnn::TensorShape({1, 2, 3, 5})));
+// Axes outside [-1 - rank, rank] must be rejected while the graph is being parsed.
+BOOST_AUTO_TEST_CASE(ParseExpandDimOutOfRangeThrows)
+    BOOST_CHECK_THROW(ExpandDimsFixture("4"), armnn::InvalidArgumentException);
+    BOOST_CHECK_THROW(ExpandDimsFixture("-5"), armnn::InvalidArgumentException);