COMPMID-3059: Add TF parser support for StridedSlice

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I31f25f26a50c9054b5650b1be127c84194b56be7
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index af86619..d65af23 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -368,6 +368,7 @@
     { "Softmax",               &TfParser::ParseSoftmax },
     { "Softplus",              &TfParser::ParseSoftplus },
     { "Split",                 &TfParser::ParseSplit },
+    { "StridedSlice",          &TfParser::ParseStridedSlice },
     { "Tanh",                  &TfParser::ParseTanh },
     { "MaxPool",               &TfParser::ParseMaxPool },
     { "AvgPool",               &TfParser::ParseAvgPool },
@@ -2760,6 +2761,54 @@
     return AddActivationLayer(nodeDef, activationDesc);
 }
 
+ParsedTfOperationPtr TfParser::ParseStridedSlice(const tensorflow::NodeDef& nodeDef,
+                                                 const tensorflow::GraphDef& graphDef)
+{
+    boost::ignore_unused(graphDef);
+
+    std::vector<OutputOfConstNodeDef> nodes = GetTfInputNodes(nodeDef);
+    unsigned int numInputs = static_cast<unsigned int>(nodes.size());
+    std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs);
+
+    ParsedConstTfOperation<int32_t>* beginNode =
+            boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[1].m_IndexedValue);
+    std::vector<int32_t> beginTensorData;
+    beginNode->GetConstTensor(beginTensorData);
+
+    ParsedConstTfOperation<int32_t>* endNode =
+            boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[2].m_IndexedValue);
+    std::vector<int32_t> endTensorData;
+    endNode->GetConstTensor(endTensorData);
+
+    ParsedConstTfOperation<int32_t>* stridesNode =
+            boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[3].m_IndexedValue);
+    std::vector<int32_t> stridesTensorData;
+    stridesNode->GetConstTensor(stridesTensorData);
+
+    StridedSliceDescriptor desc;
+    desc.m_Begin = beginTensorData;
+    desc.m_End = endTensorData;
+    desc.m_Stride = stridesTensorData;
+    desc.m_BeginMask = ReadMandatoryNodeInt32Attribute(nodeDef, "begin_mask");
+    desc.m_EndMask = ReadMandatoryNodeInt32Attribute(nodeDef, "end_mask");
+    desc.m_EllipsisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "ellipsis_mask");
+    desc.m_NewAxisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "new_axis_mask");
+    desc.m_ShrinkAxisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "shrink_axis_mask");
+    desc.m_DataLayout = armnn::DataLayout::NHWC;
+    IConnectableLayer* const layer = m_Network->AddStridedSliceLayer(desc, nodeDef.name().c_str());
+
+    IOutputSlot& prevLayerSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+    TensorInfo inputTensorInfo = prevLayerSlot.GetTensorInfo();
+
+    TensorInfo outputTensorInfo;
+    CalculateStridedSliceOutputTensorInfo(inputTensorInfo, desc, outputTensorInfo);
+
+    prevLayerSlot.Connect(layer->GetInputSlot(0));
+    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+    return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
 ParsedTfOperationPtr TfParser::ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
 {
     boost::ignore_unused(graphDef);
diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp
index 8442ca0..a7d02be 100644
--- a/src/armnnTfParser/TfParser.hpp
+++ b/src/armnnTfParser/TfParser.hpp
@@ -156,6 +156,7 @@
     ParsedTfOperationPtr ParseSoftmax(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
     ParsedTfOperationPtr ParseSoftplus(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
     ParsedTfOperationPtr ParseSplit(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+    ParsedTfOperationPtr ParseStridedSlice(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
     ParsedTfOperationPtr ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
     ParsedTfOperationPtr ParseMaxPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
     ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
diff --git a/src/armnnTfParser/test/Gather.cpp b/src/armnnTfParser/test/Gather.cpp
index f40dc57..a6c20fd 100644
--- a/src/armnnTfParser/test/Gather.cpp
+++ b/src/armnnTfParser/test/Gather.cpp
@@ -12,9 +12,10 @@
 
 BOOST_AUTO_TEST_SUITE(TensorflowParser)
 
+namespace {
 // helper for setting the dimensions in prototxt
 void dimsHelper(const std::vector<int>& dims, std::string& text){
-    for(u_int i=0; i<dims.size(); ++i){
+    for(u_int i = 0; i < dims.size(); ++i) {
         text.append(R"(dim {
       size: )");
         text.append(std::to_string(dims[i]));
@@ -25,11 +26,11 @@
 
 // helper for converting from integer to octal representation
 void octalHelper(const std::vector<int>& indicesContent, std::string& text){
-    for (unsigned int i = 0; i < indicesContent.size(); ++i)
-    {
+    for(unsigned int i = 0; i < indicesContent.size(); ++i) {
         text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(indicesContent[i])));
     }
 }
+} // namespace
 
 struct GatherFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
 {
diff --git a/src/armnnTfParser/test/StridedSlice.cpp b/src/armnnTfParser/test/StridedSlice.cpp
new file mode 100644
index 0000000..89faf75
--- /dev/null
+++ b/src/armnnTfParser/test/StridedSlice.cpp
@@ -0,0 +1,283 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "armnnTfParser/ITfParser.hpp"
+
+#include "ParserPrototxtFixture.hpp"
+#include <PrototxtConversions.hpp>
+
+#include <boost/test/unit_test.hpp>
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+namespace {
+// helper for setting the dimensions in prototxt
+void shapeHelper(const armnn::TensorShape& shape, std::string& text){
+    for(u_int i = 0; i < shape.GetNumDimensions(); ++i) {
+        text.append(R"(dim {
+      size: )");
+        text.append(std::to_string(shape[i]));
+        text.append(R"(
+    })");
+    }
+}
+
+// helper for converting from integer to octal representation
+void octalHelper(const std::vector<int>& content, std::string& text){
+    for (unsigned int i = 0; i < content.size(); ++i)
+    {
+        text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(content[i])));
+    }
+}
+} // namespace
+
+struct StridedSliceFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+    StridedSliceFixture(const armnn::TensorShape& inputShape,
+                        const std::vector<int>& beginData,
+                        const std::vector<int>& endData,
+                        const std::vector<int>& stridesData,
+                        int beginMask = 0,
+                        int endMask = 0,
+                        int ellipsisMask = 0,
+                        int newAxisMask = 0,
+                        int shrinkAxisMask = 0)
+    {
+        m_Prototext = R"(
+                         node {
+                           name: "input"
+                           op: "Placeholder"
+                           attr {
+                             key: "dtype"
+                             value {
+                               type: DT_FLOAT
+                             }
+                           }
+                           attr {
+                             key: "shape"
+                             value {
+                               shape {)";
+                                 shapeHelper(inputShape, m_Prototext);
+                                 m_Prototext.append(R"(
+                               }
+                             }
+                           }
+                         }
+                         node {
+                           name: "begin"
+                           op: "Const"
+                           attr {
+                             key: "dtype"
+                             value {
+                               type: DT_INT32
+                             }
+                           }
+                           attr {
+                             key: "value"
+                             value {
+                              tensor {
+                               dtype: DT_INT32
+                                 tensor_shape {
+                                   dim {
+                                    size: )");
+                                      m_Prototext += std::to_string(beginData.size());
+                                      m_Prototext.append(R"(
+                                    }
+                                 }
+                                 tensor_content: ")");
+                                   octalHelper(beginData, m_Prototext);
+                                   m_Prototext.append(R"("
+                               }
+                             }
+                           }
+                         }
+                         node {
+                           name: "end"
+                           op: "Const"
+                           attr {
+                             key: "dtype"
+                             value {
+                               type: DT_INT32
+                             }
+                           }
+                           attr {
+                             key: "value"
+                             value {
+                              tensor {
+                               dtype: DT_INT32
+                                 tensor_shape {
+                                   dim {
+                                    size: )");
+                                      m_Prototext += std::to_string(endData.size());
+                                      m_Prototext.append(R"(
+                                    }
+                                 }
+                                 tensor_content: ")");
+                                   octalHelper(endData, m_Prototext);
+                                   m_Prototext.append(R"("
+                               }
+                             }
+                           }
+                         }
+                         node {
+                           name: "strides"
+                           op: "Const"
+                           attr {
+                             key: "dtype"
+                             value {
+                               type: DT_INT32
+                             }
+                           }
+                           attr {
+                             key: "value"
+                             value {
+                              tensor {
+                               dtype: DT_INT32
+                                 tensor_shape {
+                                   dim {
+                                    size: )");
+                                      m_Prototext += std::to_string(stridesData.size());
+                                      m_Prototext.append(R"(
+                                    }
+                                 }
+                                 tensor_content: ")");
+                                   octalHelper(stridesData, m_Prototext);
+                                   m_Prototext.append(R"("
+                               }
+                             }
+                           }
+                         }
+                         node {
+                           name: "output"
+                           op: "StridedSlice"
+                           input: "input"
+                           input: "begin"
+                           input: "end"
+                           input: "strides"
+                           attr {
+                             key: "begin_mask"
+                             value {
+                               i: )");
+                               m_Prototext += std::to_string(beginMask);
+                               m_Prototext.append(R"(
+                             }
+                           }
+                           attr {
+                             key: "end_mask"
+                             value {
+                               i: )");
+                                 m_Prototext += std::to_string(endMask);
+                                 m_Prototext.append(R"(
+                             }
+                           }
+                           attr {
+                             key: "ellipsis_mask"
+                             value {
+                               i: )");
+                                 m_Prototext += std::to_string(ellipsisMask);
+                                 m_Prototext.append(R"(
+                             }
+                           }
+                           attr {
+                             key: "new_axis_mask"
+                             value {
+                               i: )");
+                                 m_Prototext += std::to_string(newAxisMask);
+                                 m_Prototext.append(R"(
+                             }
+                           }
+                           attr {
+                             key: "shrink_axis_mask"
+                             value {
+                               i: )");
+                                 m_Prototext += std::to_string(shrinkAxisMask);
+                                 m_Prototext.append(R"(
+                             }
+                           }
+                         })");
+
+        Setup({ { "input", inputShape } }, { "output" });
+    }
+};
+
+struct StridedSlice4DFixture : StridedSliceFixture
+{
+    StridedSlice4DFixture() : StridedSliceFixture({ 3, 2, 3, 1 },  // inputShape
+                                                  { 1, 0, 0, 0 },  // beginData
+                                                  { 2, 2, 3, 1 },  // endData
+                                                  { 1, 1, 1, 1 }   // stridesData
+    ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture)
+{
+    RunTest<4>(
+            {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
+                         3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
+                         5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
+            {{"output", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}});
+}
+
+struct StridedSlice4DReverseFixture : StridedSliceFixture
+{
+
+    StridedSlice4DReverseFixture() : StridedSliceFixture({ 3, 2, 3, 1 },   // inputShape
+                                                         { 1, -1, 0, 0 },  // beginData
+                                                         { 2, -3, 3, 1 },  // endData
+                                                         { 1, -1, 1, 1 }   // stridesData
+    ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(StridedSlice4DReverse, StridedSlice4DReverseFixture)
+{
+    RunTest<4>(
+            {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
+                         3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
+                         5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
+            {{"output", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}});
+}
+
+struct StridedSliceSimpleStrideFixture : StridedSliceFixture
+{
+    StridedSliceSimpleStrideFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
+                                                            { 0, 0, 0, 0 }, // beginData
+                                                            { 3, 2, 3, 1 }, // endData
+                                                            { 2, 2, 2, 1 }  // stridesData
+    ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleStride, StridedSliceSimpleStrideFixture)
+{
+    RunTest<4>(
+            {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
+                         3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
+                         5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
+            {{"output", { 1.0f, 1.0f,
+                          5.0f, 5.0f }}});
+}
+
+struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
+{
+    StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
+                                                               { 1, 1, 1, 1 }, // beginData
+                                                               { 1, 1, 1, 1 }, // endData
+                                                               { 1, 1, 1, 1 }, // stridesData
+                                                               (1 << 4) - 1,   // beginMask
+                                                               (1 << 4) - 1    // endMask
+    ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleRangeMask, StridedSliceSimpleRangeMaskFixture)
+{
+    RunTest<4>(
+            {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
+                         3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
+                         5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
+            {{"output", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
+                          3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
+                          5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}});
+}
+
+BOOST_AUTO_TEST_SUITE_END()