IVGCVSW-2987 Modify ParseSplit in TfLite parser

 * Allow input data with dimension not greater than 4D
 * Correct input order
 * Get split dimension from buffer data
 * Unit tests

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I285851b19e6fa7c715e5fe4853df167e7c856647
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 1ee4950..b7258b3 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1971,11 +1971,15 @@
     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
     CHECK_VALID_SIZE(outputs.size(), numSplits);
 
-    armnn::TensorInfo inputTensorInfo  = ToTensorInfo(inputs[0]);
-    armnn::TensorInfo axisTensorInfo = ToTensorInfo(inputs[1]);
+    armnn::TensorInfo inputTensorInfo  = ToTensorInfo(inputs[1]);
+    armnn::TensorInfo axisTensorInfo = ToTensorInfo(inputs[0]);
 
-    // This splitDim indicates the data format: 3 is the NHWC, 1 is the NCHW.
-    const unsigned int splitDim = static_cast<unsigned int>(axisTensorInfo.GetShape()[0]);
+    BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[0]->buffer);
+    std::vector<unsigned int> axisData(axisTensorInfo.GetNumElements());
+    ::memcpy(axisData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes());
+
+    BOOST_ASSERT(axisTensorInfo.GetNumElements() == 1);
+    const unsigned int splitDim = axisData[0];
 
     // Armnn supports split along the channel dimension for data formats NHWC and NCHW.
     if (splitDim == 0 || splitDim == 2)
@@ -1989,13 +1993,13 @@
     }
 
     auto inputDimSize = inputTensorInfo.GetNumDimensions();
-    if (inputDimSize != MaxNumOfTensorDimensions)
+    if (inputDimSize > MaxNumOfTensorDimensions)
     {
         throw ParseException(
             boost::str(
                 boost::format(
                     "The number of dimensions: %1% for input tensors of the "
-                    "split op should be %2% %3%")
+                    "split op cannot be greater than %2% %3%")
                     % inputTensorInfo.GetNumDimensions()
                     % MaxNumOfTensorDimensions
                     % CHECK_LOCATION().AsString()));
@@ -2015,7 +2019,7 @@
     }
     splitterDimSizes[splitDim] /= numSplits;
 
-    SplitterDescriptor splitDesc(numSplits);
+    SplitterDescriptor splitDesc(numSplits, inputDimSize);
     for (unsigned int j = 0; j < numSplits; ++j)
     {
         // Set the size of the views.
@@ -2030,7 +2034,7 @@
     IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str());
 
     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
-    RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+    RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[1]});
 
     TensorShape outShape = TensorShape(static_cast<unsigned int>(splitterDimSizes.size()),
                                        splitterDimSizes.data());
diff --git a/src/armnnTfLiteParser/test/Split.cpp b/src/armnnTfLiteParser/test/Split.cpp
index 774a416..a687514 100644
--- a/src/armnnTfLiteParser/test/Split.cpp
+++ b/src/armnnTfLiteParser/test/Split.cpp
@@ -14,11 +14,12 @@
 
 struct SplitFixture : public ParserFlatbuffersFixture
 {
-    explicit SplitFixture(const std::string & inputShape,
-                          const std::string & axisShape,
-                          const std::string & numSplits,
-                          const std::string & outputShape1,
-                          const std::string & outputShape2)
+    explicit SplitFixture(const std::string& inputShape,
+                          const std::string& axisShape,
+                          const std::string& numSplits,
+                          const std::string& outputShape1,
+                          const std::string& outputShape2,
+                          const std::string& axisData)
     {
         m_JsonString = R"(
             {
@@ -75,12 +76,12 @@
                             }
                         }
                     ],
-                    "inputs": [ 0, 1 ],
+                    "inputs": [ 0 ],
                     "outputs": [ 2, 3 ],
                     "operators": [
                         {
                             "opcode_index": 0,
-                            "inputs": [ 0, 1 ],
+                            "inputs": [ 1, 0 ],
                             "outputs": [ 2, 3 ],
                             "builtin_options_type": "SplitOptions",
                             "builtin_options": {
@@ -90,7 +91,7 @@
                         }
                     ],
                 } ],
-                "buffers" : [ {}, {} ]
+                "buffers" : [ {}, {"data": )" + axisData + R"( }, {}, {} ]
             }
         )";
 
@@ -101,8 +102,8 @@
 
 struct SimpleSplitFixture : SplitFixture
 {
-    SimpleSplitFixture() : SplitFixture( "[ 2, 2, 2, 2 ]", "[ 1 ]", "2",
-        "[ 2, 1, 2, 2 ]", "[ 2, 1, 2, 2 ]")
+    SimpleSplitFixture() : SplitFixture( "[ 2, 2, 2, 2 ]", "[ ]", "2",
+        "[ 2, 1, 2, 2 ]", "[ 2, 1, 2, 2 ]", "[ 1, 0, 0, 0 ]")
          {}
 };
 
@@ -113,14 +114,14 @@
         0,
         { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
                             11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } },
-        { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 9.0f, 10.0f, 11.0f, 12.0f }},
-          {"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f, 13.0f, 14.0f, 15.0f, 16.0f }}});
+        { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 9.0f, 10.0f, 11.0f, 12.0f } },
+          {"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f, 13.0f, 14.0f, 15.0f, 16.0f } } });
 }
 
 struct SimpleSplitAxisThreeFixture : SplitFixture
 {
-    SimpleSplitAxisThreeFixture() : SplitFixture( "[ 2, 2, 2, 2 ]", "[ 3 ]", "2",
-        "[ 2, 2, 2, 1 ]", "[ 2, 2, 2, 1 ]")
+    SimpleSplitAxisThreeFixture() : SplitFixture( "[ 2, 2, 2, 2 ]", "[ ]", "2",
+        "[ 2, 2, 2, 1 ]", "[ 2, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]")
     {}
 };
 
@@ -130,8 +131,39 @@
         0,
         { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
                             11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } },
-        { {"outputTensor1", { 1.0f, 3.0f, 5.0f, 7.0f, 9.0f, 11.0f, 13.0f, 15.0f }},
+        { {"outputTensor1", { 1.0f, 3.0f, 5.0f, 7.0f, 9.0f, 11.0f, 13.0f, 15.0f } },
           {"outputTensor2", { 2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f } } } );
 }
 
+struct SimpleSplit2DFixture : SplitFixture
+{
+    SimpleSplit2DFixture() : SplitFixture( "[ 1, 8 ]", "[ ]", "2", "[ 1, 4 ]", "[ 1, 4 ]", "[ 1, 0, 0, 0 ]")
+    {}
+};
+
+BOOST_FIXTURE_TEST_CASE(SimpleSplit2D, SimpleSplit2DFixture)
+{
+    RunTest<2, armnn::DataType::Float32>(
+            0,
+            { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } } },
+            { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f } },
+              {"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f } } } );
+}
+
+struct SimpleSplit3DFixture : SplitFixture
+{
+    SimpleSplit3DFixture() : SplitFixture( "[ 1, 8, 2 ]", "[ ]", "2", "[ 1, 4, 2 ]", "[ 1, 4, 2 ]", "[ 1, 0, 0, 0 ]")
+    {}
+};
+
+BOOST_FIXTURE_TEST_CASE(SimpleSplit3D, SimpleSplit3DFixture)
+{
+    RunTest<3, armnn::DataType::Float32>(
+            0,
+            { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f,
+                                10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } },
+            { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } },
+              {"outputTensor2", { 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } } );
+}
+
 BOOST_AUTO_TEST_SUITE_END()
\ No newline at end of file