Build graph->inputIds/outputIds with layerBindingId instead of layerIndex

Signed-off-by: Jung Tae-young <tee.ty.jung@openedges.com>
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Change-Id: I25ceeca70e72fad88ab039aed5a5ab6a7cc08c6c
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 99ee0b5..3bbd71a 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -768,6 +768,40 @@
             CHECK_LOCATION().AsString()));
 }
 
+unsigned int Deserializer::GetInputLayerInVector(GraphPtr graph, int targetId)
+{
+    // Position in graph->layers() of the InputLayer bound to targetId; throws ParseException if none.
+    for (unsigned int i = 0; i < graph->layers()->size(); i++)
+    {
+        auto layer = graph->layers()->Get(i);
+        if (layer->layer_type() == Layer::Layer_InputLayer &&
+            layer->layer_as_InputLayer()->base()->layerBindingId() == targetId)
+        {
+            return i;
+        }
+    }
+    throw ParseException(boost::str(
+        boost::format("Input layer with given layerBindingId %1% not found %2%") %
+        targetId % CHECK_LOCATION().AsString()));
+}
+
+unsigned int Deserializer::GetOutputLayerInVector(GraphPtr graph, int targetId)
+{
+    // Position in graph->layers() of the OutputLayer bound to targetId; throws ParseException if none.
+    for (unsigned int i = 0; i < graph->layers()->size(); i++)
+    {
+        auto layer = graph->layers()->Get(i);
+        if (layer->layer_type() == Layer::Layer_OutputLayer &&
+            layer->layer_as_OutputLayer()->base()->layerBindingId() == targetId)
+        {
+            return i;
+        }
+    }
+    throw ParseException(boost::str(
+        boost::format("Output layer with given layerBindingId %1% not found %2%") %
+        targetId % CHECK_LOCATION().AsString()));
+}
+
 unsigned int Deserializer::GetLayerIndexInVector(GraphPtr graph, unsigned int targetIndex)
 {
     for (unsigned int i = 0; i < graph->layers()->size(); i++)
@@ -781,6 +815,18 @@
     throw ParseException("Layer with given index not found");
 }
 
+Deserializer::FeatureVersions Deserializer::GetFeatureVersions(GraphPtr graph)
+{
+    // Graphs without a featureVersions table keep the zero defaults, i.e. the legacy id scheme.
+    FeatureVersions result;
+    const auto* featureTable = graph->featureVersions();
+    if (featureTable != nullptr)
+    {
+        result.m_BindingIdScheme = featureTable->bindingIdsScheme();
+    }
+    return result;
+}
+
 void Deserializer::SetupInputLayers(GraphPtr graph)
 {
     CHECK_GRAPH(graph, 0);
@@ -790,8 +836,18 @@
 
     for (unsigned int i = 0; i < numInputs; i++)
     {
-        const unsigned int inputId = graph->inputIds()->Get(i);
-        const unsigned int inputLayerIndex = GetLayerIndexInVector(graph, inputId);
+        unsigned int inputLayerIndex = 0xFFFFFFFF;
+        if (GetFeatureVersions(graph).m_BindingIdScheme == 0)
+        {
+            const unsigned int inputId = boost::numeric_cast<unsigned int>(graph->inputIds()->Get(i));
+            inputLayerIndex = GetLayerIndexInVector(graph, inputId);
+        }
+        else
+        {
+            const int inputId = graph->inputIds()->Get(i);
+            inputLayerIndex = GetInputLayerInVector(graph, inputId);
+        }
+
         LayerBaseRawPtr baseLayer = GetBaseLayer(graph, inputLayerIndex);
 
         // GetBindingLayerInfo expect the index to be index in the vector not index property on each layer base
@@ -819,8 +875,18 @@
 
     for (unsigned int i = 0; i < numOutputs; i++)
     {
-        const unsigned int outputId = graph->outputIds()->Get(i);
-        const unsigned int outputLayerIndex = GetLayerIndexInVector(graph, outputId);
+        unsigned int outputLayerIndex = 0xFFFFFFFF;
+        if (GetFeatureVersions(graph).m_BindingIdScheme == 0)
+        {
+            const unsigned int outputId = boost::numeric_cast<unsigned int>(graph->outputIds()->Get(i));
+            outputLayerIndex = GetLayerIndexInVector(graph, outputId);
+        }
+        else
+        {
+            const int outputId = graph->outputIds()->Get(i);
+            outputLayerIndex = GetOutputLayerInVector(graph, outputId);
+        }
+
         LayerBaseRawPtr baseLayer = GetBaseLayer(graph, outputLayerIndex);
 
         // GetBindingLayerInfo expect the index to be index in the vector not index property on each layer base
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index ae8be6e..8e8fe1a 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -144,9 +144,21 @@
     void SetupInputLayers(GraphPtr graphPtr);
     void SetupOutputLayers(GraphPtr graphPtr);
 
+    /// Index in the flatbuffer layers vector of the Input/Output layer with the given layerBindingId.
+    unsigned int GetInputLayerInVector(GraphPtr graph, int targetId);
+    unsigned int GetOutputLayerInVector(GraphPtr graph, int targetId);
+
     /// Helper to get the index of the layer in the flatbuffer vector from its index property
     unsigned int GetLayerIndexInVector(GraphPtr graph, unsigned int index);
 
+    struct FeatureVersions
+    {
+        // 0 (default, legacy): graph input/outputIds are layer indices; otherwise they are layerBindingIds.
+        unsigned int m_BindingIdScheme = 0;
+    };
+    /// Reads graph->featureVersions(); returns the zero defaults when the table is absent.
+    FeatureVersions GetFeatureVersions(GraphPtr graph);
+
     /// The network we're building. Gets cleared after it is passed to the user
     armnn::INetworkPtr                    m_Network;
     std::vector<LayerParsingFunction>     m_ParserFunctions;
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index 0f8a816..be6616d 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -826,11 +826,16 @@
     layer:Layer;
 }
 
+table FeatureCompatibilityVersions {
+  bindingIdsScheme:uint = 0;  // 0: graph input/outputIds are layer indices; non-zero: layerBindingIds
+}
+
 // Root type for serialized data is the graph of the network
 table SerializedGraph {
     layers:[AnyLayer];
-    inputIds:[uint];
-    outputIds:[uint];
+    inputIds:[int];   // layerBindingIds, or layer indices when featureVersions.bindingIdsScheme is 0
+    outputIds:[int];  // same encoding as inputIds
+    featureVersions:FeatureCompatibilityVersions;  // may be absent (older files); readers then assume 0
 }
 
 root_type SerializedGraph;
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 13ea0f0..b43f26c 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -91,8 +91,8 @@
     auto flatBufferInputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
                                                                                 flatBufferInputBaseLayer,
                                                                                 id);
-    // Push layer index to outputIds.
-    m_inputIds.push_back(GetSerializedId(layer->GetGuid()));
+    // Push layer binding id to inputIds.
+    m_inputIds.push_back(id);
 
     // Create the FlatBuffer InputLayer
     auto flatBufferInputLayer = serializer::CreateInputLayer(m_flatBufferBuilder, flatBufferInputBindableBaseLayer);
@@ -113,8 +113,8 @@
     auto flatBufferOutputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
                                                                                  flatBufferOutputBaseLayer,
                                                                                  id);
-    // Push layer index to outputIds.
-    m_outputIds.push_back(GetSerializedId(layer->GetGuid()));
+    // Push layer binding id to outputIds.
+    m_outputIds.push_back(id);
 
     // Create the FlatBuffer OutputLayer
     auto flatBufferOutputLayer = serializer::CreateOutputLayer(m_flatBufferBuilder, flatBufferOutputBindableBaseLayer);
@@ -1449,6 +1449,16 @@
     return flatBufferConstTensor;
 }
 
+flatbuffers::Offset<armnnSerializer::FeatureCompatibilityVersions> SerializerVisitor::GetVersionTable()
+{
+    // Scheme 1: the graph's inputIds/outputIds hold layerBindingIds instead of layer indices.
+    // Readers that see no table (or scheme 0) fall back to the legacy index-based lookup.
+    constexpr unsigned int bindingIdsScheme = 1;
+
+    return serializer::CreateFeatureCompatibilityVersions(m_flatBufferBuilder,
+                                                          bindingIdsScheme);
+}
+
 std::vector<fb::Offset<serializer::InputSlot>>
     SerializerVisitor::CreateInputSlots(const armnn::IConnectableLayer* layer)
 {
@@ -1531,7 +1541,8 @@
         fbBuilder,
         fbBuilder.CreateVector(m_SerializerVisitor.GetSerializedLayers()),
         fbBuilder.CreateVector(m_SerializerVisitor.GetInputIds()),
-        fbBuilder.CreateVector(m_SerializerVisitor.GetOutputIds()));
+        fbBuilder.CreateVector(m_SerializerVisitor.GetOutputIds()),
+        m_SerializerVisitor.GetVersionTable());
 
     // Serialize the graph
     fbBuilder.Finish(serializedGraph);
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index d92c93d..14d2776 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -29,12 +29,12 @@
         return m_flatBufferBuilder;
     }
 
-    std::vector<uint32_t>& GetInputIds()
+    std::vector<int>& GetInputIds() // layer binding ids of the serialized Input layers
     {
         return m_inputIds;
     }
 
-    std::vector<uint32_t>& GetOutputIds()
+    std::vector<int>& GetOutputIds() // layer binding ids of the serialized Output layers
     {
         return m_outputIds;
     }
@@ -44,6 +44,9 @@
         return m_serializedLayers;
     }
 
+    /// Builds the FeatureCompatibilityVersions table that is written alongside the serialized graph.
+    flatbuffers::Offset<armnnSerializer::FeatureCompatibilityVersions> GetVersionTable();
+
     ARMNN_DEPRECATED_MSG("Use VisitElementwiseUnaryLayer instead")
     void VisitAbsLayer(const armnn::IConnectableLayer* layer,
                        const char* name = nullptr) override;
@@ -301,11 +304,11 @@
     /// AnyLayers required by the SerializedGraph.
     std::vector<flatbuffers::Offset<armnnSerializer::AnyLayer>> m_serializedLayers;
 
-    /// Vector of indexes of all Input Layers required by the SerializedGraph.
-    std::vector<uint32_t> m_inputIds;
+    /// Vector of the binding ids of all Input Layers required by the SerializedGraph.
+    std::vector<int> m_inputIds;
 
-    /// Vector of indexes of all Output Layers required by the SerializedGraph.
-    std::vector<uint32_t> m_outputIds;
+    /// Vector of the binding ids of all Output Layers required by the SerializedGraph.
+    std::vector<int> m_outputIds;
 
     /// Mapped Guids of all Layers to match our index.
     std::unordered_map<armnn::LayerGuid, uint32_t > m_guidMap;