IVGCVSW-6382 Add Unsqueeze operator support to ONNX parser

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ie0b68b08fc31444c58b0ffc9babdd456bbb51f35
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 67f8997..f28c2f7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -762,12 +762,14 @@
             src/armnnOnnxParser/test/FullyConnected.cpp
             src/armnnOnnxParser/test/Gather.cpp
             src/armnnOnnxParser/test/GetInputsOutputs.cpp
+            src/armnnOnnxParser/test/OnnxParserTestUtils.cpp
             src/armnnOnnxParser/test/OnnxParserTestUtils.hpp
             src/armnnOnnxParser/test/Pooling.cpp
             src/armnnOnnxParser/test/ProtoxtFixture.cpp
             src/armnnOnnxParser/test/Relu.cpp
             src/armnnOnnxParser/test/Reshape.cpp
             src/armnnOnnxParser/test/Shape.cpp
+            src/armnnOnnxParser/test/Unsqueeze.cpp
             )
     endif()
 
diff --git a/docs/01_01_parsers.dox b/docs/01_01_parsers.dox
index 31b7687..689c062 100644
--- a/docs/01_01_parsers.dox
+++ b/docs/01_01_parsers.dox
@@ -76,6 +76,8 @@
 - Tanh
   - See the ONNX [Tanh documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Tanh) for more information.
 
+- Unsqueeze
+  - See the ONNX [Unsqueeze documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Unsqueeze) for more information.
 
 ### Partially supported
 
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index e70eb64..91ba52f 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -429,6 +429,7 @@
     { "Flatten",               &OnnxParserImpl::ParseFlatten },
     { "Shape",                 &OnnxParserImpl::ParseShape },
     { "Gather",                &OnnxParserImpl::ParseGather },
+    { "Unsqueeze",             &OnnxParserImpl::ParseUnsqueeze }
 };
 
 template<typename TypePair, typename Location>
@@ -1834,6 +1835,59 @@
     }
 }
 
+void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node)
+{
+    CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.input_size()), 1, 2);
+    CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.output_size()), 1);
+
+    CHECK_VALID_DATATYPE(node.name(), node.input(0),
+                         m_TensorsInfo[node.input(0)].m_dtype,
+                         onnx::TensorProto::FLOAT); //input
+
+    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+    std::vector<uint32_t> dims;
+    if (node.input_size() == 1)
+    {
+        dims = ReadMandatoryNodeUint32ListAttribute(node, "axes");
+    }
+    else
+    {
+        CHECK_VALID_DATATYPE(node.name(), node.input(1),
+                             m_TensorsInfo[node.input(1)].m_dtype,
+                             onnx::TensorProto::INT64); //axes
+
+        auto int64Axes = m_TensorsInfo[node.input(1)].m_tensor->int64_data().data();
+        unsigned int numDim = armnn::numeric_cast<unsigned int>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
+
+        for(unsigned int i = 0; i < numDim; i++)
+        {
+            uint32_t uint32Value = CHECKED_NON_NEGATIVE(CHECKED_INT32(int64Axes[i]));
+            dims.push_back(uint32Value);
+        }
+    }
+
+    // Ensure that the axes are sorted
+    std::sort(dims.begin(), dims.end());
+
+    std::vector<unsigned int> targetShape;
+
+    for(unsigned int i = 0; i < inputShape.GetNumDimensions(); i++)
+    {
+        targetShape.push_back(inputShape[i]);
+    }
+
+    for(unsigned int i = 0; i < dims.size(); i++)
+    {
+        targetShape.insert(targetShape.begin() + armnn::numeric_cast<int>(dims[i]), 1);
+    }
+
+    auto outInfo = ComputeReshapeInfo(TensorShape(armnn::numeric_cast<unsigned int>(targetShape.size()),
+                                                  targetShape.data()), inputShape, node.output(0));
+    m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
+
+    CreateReshapeLayer(node.input(0), node.output(0), node.name());
+}
+
 void OnnxParserImpl::PrependForBroadcast(const std::string& outputName,
                                          const std::string& input0,
                                          const std::string& input1)
diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp
index b71b8dc..196e903 100644
--- a/src/armnnOnnxParser/OnnxParser.hpp
+++ b/src/armnnOnnxParser/OnnxParser.hpp
@@ -121,6 +121,7 @@
     void ParseMaxPool(const onnx::NodeProto& nodeProto);
     void ParseShape(const onnx::NodeProto& node);
     void ParseReshape(const onnx::NodeProto& nodeProto);
+    void ParseUnsqueeze(const onnx::NodeProto& nodeProto);
 
     void RegisterInputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
     void RegisterOutputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
diff --git a/src/armnnOnnxParser/test/OnnxParserTestUtils.cpp b/src/armnnOnnxParser/test/OnnxParserTestUtils.cpp
new file mode 100644
index 0000000..66c4013
--- /dev/null
+++ b/src/armnnOnnxParser/test/OnnxParserTestUtils.cpp
@@ -0,0 +1,35 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "OnnxParserTestUtils.hpp"
+
+#include <fmt/format.h>
+
+namespace armnnUtils
+{
+
+std::string ConstructTensorShapeString(const std::vector<int>& shape)
+{
+    std::string shapeStr;
+    for (int i : shape)
+    {
+        shapeStr = fmt::format("{} dim {{ dim_value: {} }}", shapeStr, i);
+    }
+    return shapeStr;
+}
+
+std::string ConstructIntsAttribute(const std::string& name,
+                                   const std::vector<int>& values)
+{
+    std::string attrString = fmt::format("attribute {{ name: '{}'", name);
+    for (int i : values)
+    {
+        attrString = fmt::format(" {} ints: {}", attrString, i);
+    }
+    attrString = fmt::format(" {} type: INTS }}", attrString);
+    return attrString;
+}
+
+} // namespace armnnUtils
diff --git a/src/armnnOnnxParser/test/OnnxParserTestUtils.hpp b/src/armnnOnnxParser/test/OnnxParserTestUtils.hpp
index 4ed6543..212cf59 100644
--- a/src/armnnOnnxParser/test/OnnxParserTestUtils.hpp
+++ b/src/armnnOnnxParser/test/OnnxParserTestUtils.hpp
@@ -5,17 +5,14 @@
 
 #pragma once
 
+#include <string>
+#include <vector>
+
 namespace armnnUtils
 {
 
-std::string ConstructTensorShapeString(const std::vector<int>& shape)
-{
-    std::string shapeStr;
-    for (int i : shape)
-    {
-        shapeStr = fmt::format("{} dim {{ dim_value: {} }}", shapeStr, i);
-    }
-    return shapeStr;
-}
+std::string ConstructTensorShapeString(const std::vector<int>& shape);
+
+std::string ConstructIntsAttribute(const std::string& name, const std::vector<int>& values);
 
 } // namespace armnnUtils
\ No newline at end of file
diff --git a/src/armnnOnnxParser/test/Unsqueeze.cpp b/src/armnnOnnxParser/test/Unsqueeze.cpp
new file mode 100644
index 0000000..95a191e
--- /dev/null
+++ b/src/armnnOnnxParser/test/Unsqueeze.cpp
@@ -0,0 +1,197 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "armnnOnnxParser/IOnnxParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+#include "OnnxParserTestUtils.hpp"
+
+TEST_SUITE("OnnxParser_Unsqueeze")
+{
+
+struct UnsqueezeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+    UnsqueezeFixture(const std::vector<int>& axes,
+                     const std::vector<int>& inputShape,
+                     const std::vector<int>& outputShape)
+    {
+        m_Prototext = R"(
+                    ir_version: 8
+                    producer_name: "onnx-example"
+                    graph {
+                      node {
+                        input: "Input"
+                        output: "Output"
+                        op_type: "Unsqueeze"
+                        )" + armnnUtils::ConstructIntsAttribute("axes", axes) + R"(
+                      }
+                      name: "test-model"
+                      input {
+                        name: "Input"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
+                            }
+                          }
+                        }
+                      }
+                      output {
+                        name: "Output"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
+                            }
+                          }
+                        }
+                      }
+                    })";
+    }
+};
+
+struct UnsqueezeSingleAxesFixture : UnsqueezeFixture
+{
+    UnsqueezeSingleAxesFixture() : UnsqueezeFixture({ 0 }, { 2, 3 }, { 1, 2, 3 })
+    {
+        Setup();
+    }
+};
+
+struct UnsqueezeMultiAxesFixture : UnsqueezeFixture
+{
+    UnsqueezeMultiAxesFixture() : UnsqueezeFixture({ 1, 3 }, { 3, 2, 5 }, { 3, 1, 2, 1, 5 })
+    {
+        Setup();
+    }
+};
+
+struct UnsqueezeUnsortedAxesFixture : UnsqueezeFixture
+{
+    UnsqueezeUnsortedAxesFixture() : UnsqueezeFixture({ 3, 0, 1 }, { 2, 5 }, { 1, 1, 2, 1, 5 })
+    {
+        Setup();
+    }
+};
+
+TEST_CASE_FIXTURE(UnsqueezeSingleAxesFixture, "UnsqueezeSingleAxesTest")
+{
+    RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
+                      {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
+}
+
+TEST_CASE_FIXTURE(UnsqueezeMultiAxesFixture, "UnsqueezeMultiAxesTest")
+{
+    RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+                                   6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
+                                   11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
+                                   16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
+                                   21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
+                                   26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}},
+                      {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+                                    6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
+                                    11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
+                                    16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
+                                    21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
+                                    26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}});
+}
+
+TEST_CASE_FIXTURE(UnsqueezeUnsortedAxesFixture, "UnsqueezeUnsortedAxesTest")
+{
+    RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+                                   6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}},
+                      {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+                                    6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
+}
+
+struct UnsqueezeInputAxesFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+    UnsqueezeInputAxesFixture()
+    {
+        m_Prototext = R"(
+                    ir_version: 8
+                    producer_name: "onnx-example"
+                    graph {
+                      node {
+                        input: "Input"
+                        input: "Axes"
+                        output: "Output"
+                        op_type: "Unsqueeze"
+                      }
+                      initializer {
+                          dims: 2
+                          data_type: 7
+                          int64_data: 0
+                          int64_data: 3
+                          name: "Axes"
+                        }
+                      name: "test-model"
+                      input {
+                        name: "Input"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              dim {
+                                dim_value: 3
+                              }
+                              dim {
+                                dim_value: 2
+                              }
+                              dim {
+                                dim_value: 5
+                              }
+                            }
+                          }
+                        }
+                      }
+                      output {
+                        name: "Output"
+                        type {
+                          tensor_type {
+                            elem_type: 1
+                            shape {
+                              dim {
+                                dim_value: 1
+                              }
+                              dim {
+                                dim_value: 3
+                              }
+                              dim {
+                                dim_value: 2
+                              }
+                              dim {
+                                dim_value: 1
+                              }
+                              dim {
+                                dim_value: 5
+                              }
+                            }
+                          }
+                        }
+                      }
+                    })";
+        Setup();
+    }
+};
+
+TEST_CASE_FIXTURE(UnsqueezeInputAxesFixture, "UnsqueezeInputAxesTest")
+{
+    RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+                                   6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
+                                   11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
+                                   16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
+                                   21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
+                                   26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}},
+                      {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+                                    6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
+                                    11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
+                                    16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
+                                    21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
+                                    26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}});
+}
+
+}