IVGCVSW-2683 Add Serializer & Deserializer for Constant

Change-Id: Iad7d89dfa963d9015cbe044f67aecc8bf6634b10
Signed-off-by: Conor Kennedy <conor.kennedy@arm.com>
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index dc14069..db70e7b 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -92,7 +92,8 @@
     DepthwiseConvolution2d = 8,
     Activation = 9,
     Permute = 10,
-    FullyConnected = 11
+    FullyConnected = 11,
+    Constant = 12
 }
 
 // Base layer table to be used as part of other layers
@@ -125,6 +126,11 @@
     base:LayerBase;
 }
 
+table ConstantLayer {
+    base:LayerBase;      // common layer data (the shared base table used by every layer table)
+    input:ConstTensor;   // the constant tensor owned by this layer
+}
+
 table Convolution2dLayer {
     base:LayerBase;
     descriptor:Convolution2dDescriptor;
@@ -251,6 +257,7 @@
 union Layer {
     ActivationLayer,
     AdditionLayer,
+    ConstantLayer,
     Convolution2dLayer,
     DepthwiseConvolution2dLayer,
     FullyConnectedLayer,
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index a0a640e..b8f5c3b 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -141,6 +141,25 @@
     CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_AdditionLayer);
 }
 
+// Serializes a Constant layer: its base layer data plus the constant tensor it carries.
+void SerializerVisitor::VisitConstantLayer(const armnn::IConnectableLayer* layer,
+                                           const armnn::ConstTensor& input,
+                                           const char* name)
+{
+    // Base layer entry, tagged with the Constant layer type.
+    const auto constantBase = CreateLayerBase(layer, serializer::LayerType::LayerType_Constant);
+
+    // FlatBuffer form of the constant tensor held by this layer.
+    const auto constantTensor = CreateConstTensorInfo(input);
+
+    // Combine both pieces into a ConstantLayer table...
+    const auto constantLayer =
+        CreateConstantLayer(m_flatBufferBuilder, constantBase, constantTensor);
+
+    // ...and register it as a generic AnyLayer entry of the serialized graph.
+    CreateAnyLayer(constantLayer.o, serializer::Layer::Layer_ConstantLayer);
+}
+
 // Build FlatBuffer for Convolution2dLayer
 void SerializerVisitor::VisitConvolution2dLayer(const armnn::IConnectableLayer* layer,
                                                 const armnn::Convolution2dDescriptor& descriptor,
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index a423aa8..781648f 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -49,6 +49,10 @@
     void VisitAdditionLayer(const armnn::IConnectableLayer* layer,
                             const char* name = nullptr) override;
 
+    void VisitConstantLayer(const armnn::IConnectableLayer* layer,
+                            const armnn::ConstTensor& input,
+                            const char* name = nullptr) override;
+
     void VisitConvolution2dLayer(const armnn::IConnectableLayer* layer,
                                  const armnn::Convolution2dDescriptor& descriptor,
                                  const armnn::ConstTensor& weights,
diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md
index 18e9f53..9e84b63 100644
--- a/src/armnnSerializer/SerializerSupport.md
+++ b/src/armnnSerializer/SerializerSupport.md
@@ -8,6 +8,7 @@
 
 * Activation
 * Addition
+* Constant
 * Convolution2d
 * DepthwiseConvolution2d
 * FullyConnected
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index bb05052..4e90dbe 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -174,6 +174,71 @@
     deserializedNetwork->Accept(nameChecker);
 }
 
+BOOST_AUTO_TEST_CASE(SerializeConstant)
+{
+    // Serialize a constant with a real shape and data, not a default-constructed ConstTensor.
+    const armnn::TensorInfo tensorInfo({ 2, 3 }, armnn::DataType::Float32);
+    const std::vector<float> constantData = GenerateRandomData<float>(tensorInfo.GetNumElements());
+    const armnn::ConstTensor constTensor(tensorInfo, constantData);
+    armnn::INetworkPtr network = armnn::INetwork::Create();
+    armnn::IConnectableLayer* const constantLayer = network->AddConstantLayer(constTensor, "constant");
+    armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+    constantLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+    constantLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
+
+    armnnSerializer::Serializer serializer;
+    serializer.Serialize(*network);
+    std::stringstream stream;
+    serializer.SaveSerializedToStream(stream);
+    BOOST_TEST(stream.str().length() > 0);
+    BOOST_TEST(stream.str().find("constant") != stream.str().npos);
+}
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeConstant)
+{
+    class VerifyConstantName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
+    {
+    public:
+        void VisitConstantLayer(const armnn::IConnectableLayer*, const armnn::ConstTensor&, const char* name) override
+        {
+            BOOST_TEST(name == "constant");
+        }
+    };
+
+    armnn::TensorInfo commonTensorInfo({ 2, 3 }, armnn::DataType::Float32);
+
+    std::vector<float> constantData = GenerateRandomData<float>(commonTensorInfo.GetNumElements());
+    armnn::ConstTensor constTensor(commonTensorInfo, constantData);
+
+    // Builds up the structure of the network.
+    armnn::INetworkPtr net(armnn::INetwork::Create());
+
+    armnn::IConnectableLayer* input = net->AddInputLayer(0);
+    armnn::IConnectableLayer* constant = net->AddConstantLayer(constTensor, "constant");
+    armnn::IConnectableLayer* add = net->AddAdditionLayer();
+    armnn::IConnectableLayer* output = net->AddOutputLayer(0);
+
+    input->GetOutputSlot(0).Connect(add->GetInputSlot(0));
+    constant->GetOutputSlot(0).Connect(add->GetInputSlot(1));
+    add->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+    // Sets the tensors in the network.
+    input->GetOutputSlot(0).SetTensorInfo(commonTensorInfo);
+    constant->GetOutputSlot(0).SetTensorInfo(commonTensorInfo);
+    add->GetOutputSlot(0).SetTensorInfo(commonTensorInfo);
+
+    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*net));
+    BOOST_REQUIRE(deserializedNetwork); // abort the case here instead of dereferencing a null network below
+
+    VerifyConstantName nameChecker;
+    deserializedNetwork->Accept(nameChecker);
+
+    CheckDeserializedNetworkAgainstOriginal(*net,
+                                            *deserializedNetwork,
+                                            commonTensorInfo.GetShape(),
+                                            commonTensorInfo.GetShape());
+}
+
 BOOST_AUTO_TEST_CASE(SerializeMultiplication)
 {
     class VerifyMultiplicationName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>