IVGCVSW-2709 Serialize / de-serialize the Splitter layer

* fixed typo in Ref Merger Workload comment
* fixed typo in ViewsDescriptor comment
* made the origins descriptor accessible from the ViewsDescriptor
  (needed for serialization)
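* based the unit test on how the CaffeParser uses the splitter
* swapped the network arguments passed to CheckDeserializedNetworkAgainstOriginal
  in the L2Normalization serializer test, so the deserialized network is passed
  first as in the new Splitter test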

Change-Id: I3e716839adb4eee5a695633377b49e7e18ec2aa9
Signed-off-by: Ferran Balaguer <ferran.balaguer@arm.com>
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Signed-off-by: Jim Flynn <jim.flynn@arm.com>
diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp
index 43f41a7..a6339cf 100644
--- a/src/armnn/Descriptors.cpp
+++ b/src/armnn/Descriptors.cpp
@@ -290,6 +290,12 @@
     return m_ViewSizes ? m_ViewSizes[idx] : nullptr;
 }
 
+// Read-only access to the OriginsDescriptor member m_Origins (needed when serializing a ViewsDescriptor).
+const OriginsDescriptor& ViewsDescriptor::GetOrigins() const
+{
+    return m_Origins;
+}
+
 void swap(OriginsDescriptor& first, OriginsDescriptor& second)
 {
     using std::swap;
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 719e47e..ba12c37 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -18,6 +18,7 @@
 #include <boost/assert.hpp>
 #include <boost/format.hpp>
 #include <boost/log/trivial.hpp>
+#include <boost/polymorphic_cast.hpp>
 
 // The generated code based on the Serialize schema:
 #include <ArmnnSchema_generated.h>
@@ -213,6 +214,7 @@
     m_ParserFunctions[Layer_RsqrtLayer]                  = &Deserializer::ParseRsqrt;
     m_ParserFunctions[Layer_SoftmaxLayer]                = &Deserializer::ParseSoftmax;
     m_ParserFunctions[Layer_SpaceToBatchNdLayer]         = &Deserializer::ParseSpaceToBatchNd;
+    m_ParserFunctions[Layer_SplitterLayer]               = &Deserializer::ParseSplitter;
     m_ParserFunctions[Layer_StridedSliceLayer]           = &Deserializer::ParseStridedSlice;
     m_ParserFunctions[Layer_SubtractionLayer]            = &Deserializer::ParseSubtraction;
 }
@@ -283,6 +285,8 @@
             return graphPtr->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->base();
         case Layer::Layer_SpaceToBatchNdLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_SpaceToBatchNdLayer()->base();
+        case Layer::Layer_SplitterLayer:
+            return graphPtr->layers()->Get(layerIndex)->layer_as_SplitterLayer()->base();
         case Layer::Layer_StridedSliceLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_StridedSliceLayer()->base();
         case Layer::Layer_SubtractionLayer:
@@ -1831,4 +1835,78 @@
     RegisterOutputSlots(graph, layerIndex, layer);
 }
 
+// Rebuilds a Splitter layer from its serialized ViewsDescriptor. The counts and vectors are read
+// from the input buffer, so they are validated before anything is indexed with them.
+void Deserializer::ParseSplitter(GraphPtr graph, unsigned int layerIndex)
+{
+    CHECK_LAYERS(graph, 0, layerIndex);
+
+    Deserializer::TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
+    CHECK_VALID_SIZE(inputs.size(), 1);
+
+    Deserializer::TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
+
+    // FlatBuffers returns nullptr for table and vector fields that are absent from the buffer.
+    auto flatBufferViewsDescriptor = graph->layers()->Get(layerIndex)->layer_as_SplitterLayer()->descriptor();
+    if (flatBufferViewsDescriptor == nullptr ||
+        flatBufferViewsDescriptor->viewSizes() == nullptr ||
+        flatBufferViewsDescriptor->origins() == nullptr ||
+        flatBufferViewsDescriptor->origins()->viewOrigins() == nullptr)
+    {
+        throw armnn::ParseException(boost::str(
+            boost::format("Splitter layer %1% has an incomplete ViewsDescriptor") % layerIndex));
+    }
+
+    auto flatBufferViewSizes = flatBufferViewsDescriptor->viewSizes();
+    auto flatBufferOriginsDescriptor = flatBufferViewsDescriptor->origins();
+    auto flatBufferViewOrigins = flatBufferOriginsDescriptor->viewOrigins();
+    uint32_t numViews = flatBufferOriginsDescriptor->numViews();
+    uint32_t numDimensions = flatBufferOriginsDescriptor->numDimensions();
+
+    // flatbuffers::Vector::Get() only asserts its index, so in release builds the number of
+    // serialized views must be checked against numViews before the vectors are indexed below.
+    if (flatBufferViewSizes->size() != numViews || flatBufferViewOrigins->size() != numViews)
+    {
+        throw armnn::ParseException(boost::str(
+            boost::format("Splitter layer %1% expects %2% views but has %3% view sizes and %4% view origins")
+            % layerIndex % numViews % flatBufferViewSizes->size() % flatBufferViewOrigins->size()));
+    }
+
+    // Each view becomes an output slot, which needs a serialized output tensor.
+    CHECK_VALID_SIZE(outputs.size(), numViews);
+
+    armnn::ViewsDescriptor viewsDescriptor(numViews, numDimensions);
+    for (unsigned int vIdx = 0; vIdx < numViews; ++vIdx)
+    {
+        auto viewSize = flatBufferViewSizes->Get(vIdx)->data();
+        auto viewOrigin = flatBufferViewOrigins->Get(vIdx)->data();
+        if (viewSize == nullptr || viewOrigin == nullptr ||
+            viewSize->size() != numDimensions || viewOrigin->size() != numDimensions)
+        {
+            throw armnn::ParseException(boost::str(
+                boost::format("Splitter layer %1%: view %2% must have %3% size and origin values")
+                % layerIndex % vIdx % numDimensions));
+        }
+
+        for (unsigned int dIdx = 0; dIdx < numDimensions; ++dIdx)
+        {
+            viewsDescriptor.SetViewSize(vIdx, dIdx, viewSize->Get(dIdx));
+            viewsDescriptor.SetViewOriginCoord(vIdx, dIdx, viewOrigin->Get(dIdx));
+        }
+    }
+
+    auto layerName = GetLayerName(graph, layerIndex);
+    IConnectableLayer* layer = m_Network->AddSplitterLayer(viewsDescriptor, layerName.c_str());
+
+    // The splitter has one output slot per view.
+    for (unsigned int vIdx = 0; vIdx < numViews; ++vIdx)
+    {
+        armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[vIdx]);
+        layer->GetOutputSlot(vIdx).SetTensorInfo(outputTensorInfo);
+    }
+
+    RegisterInputSlots(graph, layerIndex, layer);
+    RegisterOutputSlots(graph, layerIndex, layer);
+}
+
 } // namespace armnnDeserializer
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index 3006481..67e6e84 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -101,6 +101,7 @@
     void ParseRsqrt(GraphPtr graph, unsigned int layerIndex);
     void ParseSoftmax(GraphPtr graph, unsigned int layerIndex);
     void ParseSpaceToBatchNd(GraphPtr graph, unsigned int layerIndex);
+    void ParseSplitter(GraphPtr graph, unsigned int layerIndex);
     void ParseStridedSlice(GraphPtr graph, unsigned int layerIndex);
     void ParseSubtraction(GraphPtr graph, unsigned int layerIndex);
 
diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md
index ceeae59..398489b 100644
--- a/src/armnnDeserializer/DeserializerSupport.md
+++ b/src/armnnDeserializer/DeserializerSupport.md
@@ -34,6 +34,7 @@
 * Rsqrt
 * Softmax
 * SpaceToBatchNd
+* Splitter
 * StridedSlice
 * Subtraction
 
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index 36389b7..40ee7a5 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -112,7 +112,8 @@
     Gather = 28,
     Mean = 29,
     Merger = 30,
-    L2Normalization = 31
+    L2Normalization = 31,
+    Splitter = 32
 }
 
 // Base layer table to be used as part of other layers
@@ -442,6 +443,18 @@
    viewOrigins:[UintVector];
 }
 
+// Descriptor of a Splitter layer: the origin of every view plus the size of every view
+// (one UintVector per view, in view order)
+table ViewsDescriptor {
+   origins:OriginsDescriptor;
+   viewSizes:[UintVector];
+}
+
+table SplitterLayer {
+   base:LayerBase;
+   descriptor:ViewsDescriptor;
+}
+
 union Layer {
     ActivationLayer,
     AdditionLayer,
@@ -474,7 +487,8 @@
     GatherLayer,
     MeanLayer,
     MergerLayer,
-    L2NormalizationLayer
+    L2NormalizationLayer,
+    SplitterLayer
 }
 
 table AnyLayer {
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index eaf19d5..3774c25 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -634,6 +634,59 @@
     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_SpaceToBatchNdLayer);
 }
 
+// Build FlatBuffer for Splitter Layer
+void SerializerVisitor::VisitSplitterLayer(const armnn::IConnectableLayer* layer,
+                                           const armnn::ViewsDescriptor& viewsDescriptor,
+                                           const char* name)
+{
+    const uint32_t numViews = viewsDescriptor.GetNumViews();
+    const uint32_t numDimensions = viewsDescriptor.GetNumDimensions();
+
+    // Both the view origins and the view sizes are serialized as one UintVector of
+    // numDimensions values per view.
+    auto createViewVector = [&](const uint32_t* values)
+    {
+        return CreateUintVector(m_flatBufferBuilder, m_flatBufferBuilder.CreateVector(values, numDimensions));
+    };
+
+    // Create FlatBuffer ViewOrigins
+    std::vector<flatbuffers::Offset<UintVector>> flatBufferViewOrigins;
+    flatBufferViewOrigins.reserve(numViews);
+    for (uint32_t vIdx = 0; vIdx < numViews; ++vIdx)
+    {
+        flatBufferViewOrigins.push_back(createViewVector(viewsDescriptor.GetViewOrigin(vIdx)));
+    }
+
+    // Create FlatBuffer OriginsDescriptor
+    auto flatBufferOriginDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
+                                                              viewsDescriptor.GetOrigins().GetConcatAxis(),
+                                                              viewsDescriptor.GetOrigins().GetNumViews(),
+                                                              viewsDescriptor.GetOrigins().GetNumDimensions(),
+                                                              m_flatBufferBuilder.CreateVector(flatBufferViewOrigins));
+
+    // Create FlatBuffer ViewSizes
+    std::vector<flatbuffers::Offset<UintVector>> flatBufferViewSizes;
+    flatBufferViewSizes.reserve(numViews);
+    for (uint32_t vIdx = 0; vIdx < numViews; ++vIdx)
+    {
+        flatBufferViewSizes.push_back(createViewVector(viewsDescriptor.GetViewSizes(vIdx)));
+    }
+
+    // Create FlatBuffer ViewsDescriptor
+    auto flatBufferViewsDescriptor = CreateViewsDescriptor(m_flatBufferBuilder,
+                                                           flatBufferOriginDescriptor,
+                                                           m_flatBufferBuilder.CreateVector(flatBufferViewSizes));
+
+    // Create FlatBuffer BaseLayer
+    auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Splitter);
+
+    auto flatBufferSplitterLayer = serializer::CreateSplitterLayer(m_flatBufferBuilder,
+                                                                   flatBufferBaseLayer,
+                                                                   flatBufferViewsDescriptor);
+
+    CreateAnyLayer(flatBufferSplitterLayer.o, serializer::Layer::Layer_SplitterLayer);
+}
+
 void SerializerVisitor::VisitNormalizationLayer(const armnn::IConnectableLayer* layer,
                                                 const armnn::NormalizationDescriptor& descriptor,
                                                 const char* name)
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index c0e70c9..cb05792 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -162,6 +162,10 @@
                                  const armnn::NormalizationDescriptor& normalizationDescriptor,
                                  const char* name = nullptr) override;
 
+    void VisitSplitterLayer(const armnn::IConnectableLayer* layer,
+                            const armnn::ViewsDescriptor& viewsDescriptor,
+                            const char* name = nullptr) override;
+
     void VisitStridedSliceLayer(const armnn::IConnectableLayer* layer,
                                 const armnn::StridedSliceDescriptor& stridedSliceDescriptor,
                                 const char* name = nullptr) override;
diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md
index d557756..81d2faa 100644
--- a/src/armnnSerializer/SerializerSupport.md
+++ b/src/armnnSerializer/SerializerSupport.md
@@ -34,6 +34,7 @@
 * Rsqrt
 * Softmax
 * SpaceToBatchNd
+* Splitter
 * StridedSlice
 * Subtraction
 
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index 069b9d6..41f5d14 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -473,8 +473,8 @@
     VerifyL2NormalizationName nameChecker(l2NormLayerName);
     deserializedNetwork->Accept(nameChecker);
 
-    CheckDeserializedNetworkAgainstOriginal<float>(*network,
-                                                   *deserializedNetwork,
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
                                                    { info.GetShape() },
                                                    { info.GetShape() });
 }
@@ -1520,4 +1520,103 @@
                                                    {0, 1});
 }
 
+BOOST_AUTO_TEST_CASE(SerializeDeserializeSplitter)
+{
+    // Checks that the deserialized splitter keeps the original layer name and view layout.
+    class VerifySplitterLayer : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
+    {
+    public:
+        VerifySplitterLayer(const std::string& expectedName, const armnn::ViewsDescriptor& expectedDescriptor)
+            : m_ExpectedName(expectedName)
+            , m_ExpectedDescriptor(expectedDescriptor)
+        {}
+
+        void VisitSplitterLayer(const armnn::IConnectableLayer*,
+                                const armnn::ViewsDescriptor& viewsDescriptor,
+                                const char* name) override
+        {
+            BOOST_REQUIRE(name != nullptr);
+            BOOST_TEST(m_ExpectedName == name);
+
+            // Compare the counts first so the per-view arrays below are only read in range.
+            BOOST_REQUIRE_EQUAL(viewsDescriptor.GetNumViews(), m_ExpectedDescriptor.GetNumViews());
+            BOOST_REQUIRE_EQUAL(viewsDescriptor.GetNumDimensions(), m_ExpectedDescriptor.GetNumDimensions());
+
+            for (unsigned int vIdx = 0; vIdx < m_ExpectedDescriptor.GetNumViews(); ++vIdx)
+            {
+                for (unsigned int dIdx = 0; dIdx < m_ExpectedDescriptor.GetNumDimensions(); ++dIdx)
+                {
+                    BOOST_TEST(viewsDescriptor.GetViewOrigin(vIdx)[dIdx] ==
+                               m_ExpectedDescriptor.GetViewOrigin(vIdx)[dIdx]);
+                    BOOST_TEST(viewsDescriptor.GetViewSizes(vIdx)[dIdx] ==
+                               m_ExpectedDescriptor.GetViewSizes(vIdx)[dIdx]);
+                }
+            }
+        }
+
+    private:
+        std::string m_ExpectedName;
+        const armnn::ViewsDescriptor& m_ExpectedDescriptor;
+    };
+
+    const unsigned int numViews = 3;
+    const unsigned int numDimensions = 4;
+    const unsigned int inputShape[] = {1, 18, 4, 4};
+    const unsigned int outputShape[] = {1, 6, 4, 4};
+
+    auto inputTensorInfo = armnn::TensorInfo(numDimensions, inputShape, armnn::DataType::Float32);
+    auto outputTensorInfo = armnn::TensorInfo(numDimensions, outputShape, armnn::DataType::Float32);
+
+    // This is modelled on how the caffe parser sets up a splitter layer to partition an input
+    // along dimension one: each view keeps the input shape, except that dimension one is
+    // divided evenly between the views and each view starts where the previous one ended.
+    unsigned int splitterDimSizes[numDimensions] = {inputShape[0], inputShape[1], inputShape[2], inputShape[3]};
+    splitterDimSizes[1] /= numViews;
+    armnn::ViewsDescriptor desc(numViews, numDimensions);
+
+    for (unsigned int g = 0; g < numViews; ++g)
+    {
+        desc.SetViewOriginCoord(g, 1, splitterDimSizes[1] * g);
+
+        // Set the size of the views.
+        for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
+        {
+            desc.SetViewSize(g, dimIdx, splitterDimSizes[dimIdx]);
+        }
+    }
+
+    const char* splitterLayerName = "splitter";
+
+    armnn::INetworkPtr network = armnn::INetwork::Create();
+    armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+    armnn::IConnectableLayer* const splitterLayer = network->AddSplitterLayer(desc, splitterLayerName);
+    armnn::IConnectableLayer* const outputLayer0 = network->AddOutputLayer(0);
+    armnn::IConnectableLayer* const outputLayer1 = network->AddOutputLayer(1);
+    armnn::IConnectableLayer* const outputLayer2 = network->AddOutputLayer(2);
+
+    inputLayer->GetOutputSlot(0).Connect(splitterLayer->GetInputSlot(0));
+    inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+    splitterLayer->GetOutputSlot(0).Connect(outputLayer0->GetInputSlot(0));
+    splitterLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+    splitterLayer->GetOutputSlot(1).Connect(outputLayer1->GetInputSlot(0));
+    splitterLayer->GetOutputSlot(1).SetTensorInfo(outputTensorInfo);
+    splitterLayer->GetOutputSlot(2).Connect(outputLayer2->GetInputSlot(0));
+    splitterLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo);
+
+    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+    BOOST_CHECK(deserializedNetwork);
+
+    VerifySplitterLayer layerChecker(splitterLayerName, desc);
+    deserializedNetwork->Accept(layerChecker);
+
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputTensorInfo.GetShape()},
+                                                   {outputTensorInfo.GetShape(),
+                                                    outputTensorInfo.GetShape(),
+                                                    outputTensorInfo.GetShape()},
+                                                   {0},
+                                                   {0, 1, 2});
+}
+
 BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/backends/reference/workloads/Merger.cpp b/src/backends/reference/workloads/Merger.cpp
index 10cc249..8877ee2 100644
--- a/src/backends/reference/workloads/Merger.cpp
+++ b/src/backends/reference/workloads/Merger.cpp
@@ -21,7 +21,7 @@
     if (sourceInfo.GetQuantizationScale() != destInfo.GetQuantizationScale() ||
         sourceInfo.GetQuantizationOffset() != destInfo.GetQuantizationOffset())
     {
-        // Dequantize value acording to sourceInfo params
+        // Dequantize value according to sourceInfo params
         float dequantizedValue = armnn::Dequantize<uint8_t>(source,
                                                             sourceInfo.GetQuantizationScale(),
                                                             sourceInfo.GetQuantizationOffset());