IVGCVSW-2685 Serialize / de-serialize the DepthwiseConvolution2d layer

Change-Id: I37e360c824b30cb14cbef86f6ff7636bc9382109
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
diff --git a/src/armnnSerializer/Schema.fbs b/src/armnnSerializer/Schema.fbs
index cbc7da0..6c542b1 100644
--- a/src/armnnSerializer/Schema.fbs
+++ b/src/armnnSerializer/Schema.fbs
@@ -75,7 +75,8 @@
     Pooling2d = 4,
     Reshape = 5,
     Softmax = 6,
-    Convolution2d = 7
+    Convolution2d = 7,
+    DepthwiseConvolution2d = 8
 }
 
 // Base layer table to be used as part of other layers
@@ -168,6 +169,24 @@
     beta:float;
 }
 
+table DepthwiseConvolution2dLayer {
+    base:LayerBase;
+    descriptor:DepthwiseConvolution2dDescriptor;
+    weights:ConstTensor;
+    biases:ConstTensor;
+}
+
+table DepthwiseConvolution2dDescriptor {
+    padLeft:uint;
+    padRight:uint;
+    padTop:uint;
+    padBottom:uint;
+    strideX:uint;
+    strideY:uint;
+    biasEnabled:bool = false;
+    dataLayout:DataLayout = NCHW;
+}
+
 table OutputLayer {
     base:BindableLayerBase;
 }
@@ -184,6 +203,7 @@
 union Layer {
     AdditionLayer,
     Convolution2dLayer,
+    DepthwiseConvolution2dLayer,
     InputLayer,
     MultiplicationLayer,
     OutputLayer,
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index d6a23cc..27204a0 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -129,6 +129,39 @@
     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_Convolution2dLayer);
 }
 
+void SerializerVisitor::VisitDepthwiseConvolution2dLayer(const IConnectableLayer* layer,
+                                                         const DepthwiseConvolution2dDescriptor& descriptor,
+                                                         const ConstTensor& weights,
+                                                         const Optional<ConstTensor>& biases,
+                                                         const char* name)
+{
+    auto fbBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_DepthwiseConvolution2d);
+    auto fbDescriptor = CreateDepthwiseConvolution2dDescriptor(m_flatBufferBuilder,
+                                                               descriptor.m_PadLeft,
+                                                               descriptor.m_PadRight,
+                                                               descriptor.m_PadTop,
+                                                               descriptor.m_PadBottom,
+                                                               descriptor.m_StrideX,
+                                                               descriptor.m_StrideY,
+                                                               descriptor.m_BiasEnabled,
+                                                               GetFlatBufferDataLayout(descriptor.m_DataLayout));
+
+    flatbuffers::Offset<serializer::ConstTensor> fbWeightsConstTensorInfo = CreateConstTensorInfo(weights);
+    flatbuffers::Offset<serializer::ConstTensor> fbBiasesConstTensorInfo;
+    if (biases.has_value())
+    {
+        fbBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
+    }
+
+    auto flatBufferLayer = CreateDepthwiseConvolution2dLayer(m_flatBufferBuilder,
+                                                             fbBaseLayer,
+                                                             fbDescriptor,
+                                                             fbWeightsConstTensorInfo,
+                                                             fbBiasesConstTensorInfo);
+
+    CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_DepthwiseConvolution2dLayer);
+}
+
 // Build FlatBuffer for Multiplication Layer
 void SerializerVisitor::VisitMultiplicationLayer(const IConnectableLayer* layer, const char* name)
 {
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index fd1a792..907d4ed 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -51,20 +51,22 @@
                                  const armnn::Optional<armnn::ConstTensor>& biases,
                                  const char* = nullptr) override;
 
+    void VisitDepthwiseConvolution2dLayer(const armnn::IConnectableLayer* layer,
+                                          const armnn::DepthwiseConvolution2dDescriptor& descriptor,
+                                          const armnn::ConstTensor& weights,
+                                          const armnn::Optional<armnn::ConstTensor>& biases,
+                                          const char* name = nullptr) override;
+
     void VisitInputLayer(const armnn::IConnectableLayer* layer,
                          armnn::LayerBindingId id,
                          const char* name = nullptr) override;
 
-    void VisitOutputLayer(const armnn::IConnectableLayer* layer,
-                          armnn::LayerBindingId id,
-                          const char* name = nullptr) override;
-
     void VisitMultiplicationLayer(const armnn::IConnectableLayer* layer,
                                   const char* name = nullptr) override;
 
-    void VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
-                           const armnn::SoftmaxDescriptor& softmaxDescriptor,
-                           const char* name = nullptr) override;
+    void VisitOutputLayer(const armnn::IConnectableLayer* layer,
+                          armnn::LayerBindingId id,
+                          const char* name = nullptr) override;
 
     void VisitPooling2dLayer(const armnn::IConnectableLayer* layer,
                              const armnn::Pooling2dDescriptor& pooling2dDescriptor,
@@ -74,6 +76,10 @@
                            const armnn::ReshapeDescriptor& reshapeDescriptor,
                            const char* name = nullptr) override;
 
+    void VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
+                           const armnn::SoftmaxDescriptor& softmaxDescriptor,
+                           const char* name = nullptr) override;
+
 private:
 
     /// Creates the Input Slots and Output Slots and LayerBase for the layer.
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index 31ef045..a88193d 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -10,7 +10,7 @@
 
 #include <armnnDeserializeParser/IDeserializeParser.hpp>
 
-#include <numeric>
+#include <random>
 #include <sstream>
 #include <vector>
 
@@ -40,11 +40,99 @@
     return serializerString;
 }
 
+template<typename DataType>
+static std::vector<DataType> GenerateRandomData(size_t size)
+{
+    constexpr bool isIntegerType = std::is_integral<DataType>::value;
+    using Distribution =
+        typename std::conditional<isIntegerType,
+                                  std::uniform_int_distribution<DataType>,
+                                  std::uniform_real_distribution<DataType>>::type;
+
+    static constexpr DataType lowerLimit = std::numeric_limits<DataType>::min();
+    static constexpr DataType upperLimit = std::numeric_limits<DataType>::max();
+
+    static Distribution distribution(lowerLimit, upperLimit);
+    static std::default_random_engine generator;
+
+    std::vector<DataType> randomData(size);
+    std::generate(randomData.begin(), randomData.end(), []() { return distribution(generator); });
+
+    return randomData;
+}
+
+void CheckDeserializedNetworkAgainstOriginal(const armnn::INetwork& originalNetwork,
+                                             const armnn::INetwork& deserializedNetwork,
+                                             const armnn::TensorShape& inputShape,
+                                             const armnn::TensorShape& outputShape,
+                                             armnn::LayerBindingId inputBindingId = 0,
+                                             armnn::LayerBindingId outputBindingId = 0)
+{
+    armnn::IRuntime::CreationOptions options;
+    armnn::IRuntimePtr runtime = armnn::IRuntime::Create(options);
+
+    std::vector<armnn::BackendId> preferredBackends = { armnn::BackendId("CpuRef") };
+
+    // Optimize original network
+    armnn::IOptimizedNetworkPtr optimizedOriginalNetwork =
+        armnn::Optimize(originalNetwork, preferredBackends, runtime->GetDeviceSpec());
+    BOOST_CHECK(optimizedOriginalNetwork);
+
+    // Optimize deserialized network
+    armnn::IOptimizedNetworkPtr optimizedDeserializedNetwork =
+        armnn::Optimize(deserializedNetwork, preferredBackends, runtime->GetDeviceSpec());
+    BOOST_CHECK(optimizedDeserializedNetwork);
+
+    armnn::NetworkId networkId1;
+    armnn::NetworkId networkId2;
+
+    // Load original and deserialized network
+    armnn::Status status1 = runtime->LoadNetwork(networkId1, std::move(optimizedOriginalNetwork));
+    BOOST_CHECK(status1 == armnn::Status::Success);
+
+    armnn::Status status2 = runtime->LoadNetwork(networkId2, std::move(optimizedDeserializedNetwork));
+    BOOST_CHECK(status2 == armnn::Status::Success);
+
+    // Generate some input data
+    std::vector<float> inputData = GenerateRandomData<float>(inputShape.GetNumElements());
+
+    armnn::InputTensors inputTensors1
+    {
+         { inputBindingId, armnn::ConstTensor(runtime->GetInputTensorInfo(networkId1, inputBindingId), inputData.data()) }
+    };
+
+    armnn::InputTensors inputTensors2
+    {
+         { inputBindingId, armnn::ConstTensor(runtime->GetInputTensorInfo(networkId2, inputBindingId), inputData.data()) }
+    };
+
+    std::vector<float> outputData1(outputShape.GetNumElements());
+    std::vector<float> outputData2(outputShape.GetNumElements());
+
+    armnn::OutputTensors outputTensors1
+    {
+         { outputBindingId, armnn::Tensor(runtime->GetOutputTensorInfo(networkId1, outputBindingId), outputData1.data()) }
+    };
+
+    armnn::OutputTensors outputTensors2
+    {
+         { outputBindingId, armnn::Tensor(runtime->GetOutputTensorInfo(networkId2, outputBindingId), outputData2.data()) }
+    };
+
+    // Run original and deserialized network
+    runtime->EnqueueWorkload(networkId1, inputTensors1, outputTensors1);
+    runtime->EnqueueWorkload(networkId2, inputTensors2, outputTensors2);
+
+    // Compare output data
+    BOOST_CHECK_EQUAL_COLLECTIONS(outputData1.begin(), outputData1.end(),
+                                  outputData2.begin(), outputData2.end());
+}
+
 } // anonymous namespace
 
 BOOST_AUTO_TEST_SUITE(SerializerTests)
 
-BOOST_AUTO_TEST_CASE(SimpleNetworkSerialization)
+BOOST_AUTO_TEST_CASE(SerializeAddition)
 {
     armnn::INetworkPtr network = armnn::INetwork::Create();
     armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0);
@@ -65,88 +153,7 @@
     BOOST_TEST(stream.str().length() > 0);
 }
 
-BOOST_AUTO_TEST_CASE(Conv2dSerialization)
-{
-    armnn::IRuntime::CreationOptions options; // default options
-    armnn::IRuntimePtr run = armnn::IRuntime::Create(options);
-
-    armnnDeserializeParser::IDeserializeParserPtr parser = armnnDeserializeParser::IDeserializeParser::Create();
-
-    armnn::TensorInfo inputInfo(armnn::TensorShape({1, 5, 5, 1}), armnn::DataType::Float32, 1.0f, 0);
-    armnn::TensorInfo outputInfo(armnn::TensorShape({1, 3, 3, 1}), armnn::DataType::Float32, 4.0f, 0);
-
-    armnn::TensorInfo weightsInfo(armnn::TensorShape({1, 3, 3, 1}), armnn::DataType::Float32, 2.0f, 0);
-
-    std::vector<float> weightsData({4, 5, 6, 0, 0, 0, 3, 2, 1});
-
-    // Construct network
-    armnn::INetworkPtr network = armnn::INetwork::Create();
-
-    armnn::Convolution2dDescriptor descriptor;
-    descriptor.m_PadLeft = 1;
-    descriptor.m_PadRight = 1;
-    descriptor.m_PadTop = 1;
-    descriptor.m_PadBottom = 1;
-    descriptor.m_StrideX = 2;
-    descriptor.m_StrideY = 2;
-    descriptor.m_BiasEnabled = false;
-    descriptor.m_DataLayout = armnn::DataLayout::NHWC;
-
-    armnn::ConstTensor weights(weightsInfo, weightsData);
-
-    armnn::IConnectableLayer* const inputLayer  = network->AddInputLayer(0, "input");
-    armnn::IConnectableLayer* const convLayer   = network->AddConvolution2dLayer(descriptor, weights, "conv");
-    armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0, "output");
-
-    inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0));
-    inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
-
-    convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
-    convLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
-
-    armnnSerializer::Serializer serializer;
-    serializer.Serialize(*network);
-
-    std::stringstream stream;
-    serializer.SaveSerializedToStream(stream);
-
-    std::string const serializerString{stream.str()};
-    std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
-
-    armnn::INetworkPtr deserializedNetwork = parser->CreateNetworkFromBinary(serializerVector);
-
-    auto deserializedOptimized = Optimize(*deserializedNetwork, {armnn::Compute::CpuRef}, run->GetDeviceSpec());
-
-    armnn::NetworkId networkIdentifier;
-
-    // Load graph into runtime
-    run->LoadNetwork(networkIdentifier, std::move(deserializedOptimized));
-
-    std::vector<float> inputData
-    {
-            1, 5, 2, 3, 5, 8, 7, 3, 6, 3, 3, 3, 9, 1, 9, 4, 1, 8, 1, 3, 6, 8, 1, 9, 2
-    };
-    armnn::InputTensors inputTensors
-    {
-            {0, armnn::ConstTensor(run->GetInputTensorInfo(networkIdentifier, 0), inputData.data())}
-    };
-
-    std::vector<float> expectedOutputData
-    {
-            23, 33, 24, 91, 99, 48, 26, 50, 19
-    };
-
-    std::vector<float> outputData(9);
-    armnn::OutputTensors outputTensors
-    {
-            {0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())}
-    };
-    run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
-    BOOST_CHECK_EQUAL_COLLECTIONS(outputData.begin(), outputData.end(),
-                                  expectedOutputData.begin(), expectedOutputData.end());
-}
-
-BOOST_AUTO_TEST_CASE(SimpleNetworkWithMultiplicationSerialization)
+BOOST_AUTO_TEST_CASE(SerializeMultiplication)
 {
     const armnn::TensorInfo info({ 1, 5, 2, 3 }, armnn::DataType::Float32);
 
@@ -172,14 +179,57 @@
     BOOST_TEST(stream.str().find(multLayerName) != stream.str().npos);
 }
 
-BOOST_AUTO_TEST_CASE(SimpleReshapeIntegration)
+BOOST_AUTO_TEST_CASE(SerializeDeserializeConvolution2d)
 {
-    armnn::NetworkId networkIdentifier;
-    armnn::IRuntime::CreationOptions options; // default options
-    armnn::IRuntimePtr run = armnn::IRuntime::Create(options);
+    armnn::TensorInfo inputInfo ({ 1, 5, 5, 1 }, armnn::DataType::Float32);
+    armnn::TensorInfo outputInfo({ 1, 3, 3, 1 }, armnn::DataType::Float32);
 
-    unsigned int inputShape[] = {1, 9};
-    unsigned int outputShape[] = {3, 3};
+    armnn::TensorInfo weightsInfo({ 1, 3, 3, 1 }, armnn::DataType::Float32);
+    armnn::TensorInfo biasesInfo ({ 1 }, armnn::DataType::Float32);
+
+    // Construct network
+    armnn::INetworkPtr network = armnn::INetwork::Create();
+
+    armnn::Convolution2dDescriptor descriptor;
+    descriptor.m_PadLeft     = 1;
+    descriptor.m_PadRight    = 1;
+    descriptor.m_PadTop      = 1;
+    descriptor.m_PadBottom   = 1;
+    descriptor.m_StrideX     = 2;
+    descriptor.m_StrideY     = 2;
+    descriptor.m_BiasEnabled = true;
+    descriptor.m_DataLayout  = armnn::DataLayout::NHWC;
+
+    std::vector<float> weightsData = GenerateRandomData<float>(weightsInfo.GetNumElements());
+    armnn::ConstTensor weights(weightsInfo, weightsData);
+
+    std::vector<float> biasesData = GenerateRandomData<float>(biasesInfo.GetNumElements());
+    armnn::ConstTensor biases(biasesInfo, biasesData);
+
+    armnn::IConnectableLayer* const inputLayer  = network->AddInputLayer(0, "input");
+    armnn::IConnectableLayer* const convLayer   =
+        network->AddConvolution2dLayer(descriptor, weights, biases, "convolution");
+    armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0, "output");
+
+    inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0));
+    inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
+
+    convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+    convLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+    BOOST_CHECK(deserializedNetwork);
+
+    CheckDeserializedNetworkAgainstOriginal(*network,
+                                            *deserializedNetwork,
+                                            inputInfo.GetShape(),
+                                            outputInfo.GetShape());
+}
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeReshape)
+{
+    unsigned int inputShape[]  = { 1, 9 };
+    unsigned int outputShape[] = { 3, 3 };
 
     auto inputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::DataType::Float32);
     auto outputTensorInfo = armnn::TensorInfo(2, outputShape, armnn::DataType::Float32);
@@ -198,49 +248,62 @@
     reshapeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
     reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
 
-    armnnSerializer::Serializer serializer;
-    serializer.Serialize(*network);
-    std::stringstream stream;
-    serializer.SaveSerializedToStream(stream);
-    std::string const serializerString{stream.str()};
+    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+    BOOST_CHECK(deserializedNetwork);
 
-    //Deserialize network.
-    auto deserializedNetwork = DeserializeNetwork(serializerString);
-
-    //Optimize the deserialized network
-    auto deserializedOptimized = Optimize(*deserializedNetwork, {armnn::Compute::CpuRef},
-                                          run->GetDeviceSpec());
-
-    // Load graph into runtime
-    run->LoadNetwork(networkIdentifier, std::move(deserializedOptimized));
-
-    std::vector<float> input1Data(inputTensorInfo.GetNumElements());
-    std::iota(input1Data.begin(), input1Data.end(), 8);
-
-    armnn::InputTensors inputTensors
-    {
-         {0, armnn::ConstTensor(run->GetInputTensorInfo(networkIdentifier, 0), input1Data.data())}
-    };
-
-    std::vector<float> outputData(input1Data.size());
-    armnn::OutputTensors outputTensors
-    {
-         {0,armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())}
-    };
-
-    run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
-
-    BOOST_CHECK_EQUAL_COLLECTIONS(outputData.begin(),outputData.end(), input1Data.begin(),input1Data.end());
+    CheckDeserializedNetworkAgainstOriginal(*network,
+                                            *deserializedNetwork,
+                                            inputTensorInfo.GetShape(),
+                                            outputTensorInfo.GetShape());
 }
 
-BOOST_AUTO_TEST_CASE(SimpleSoftmaxIntegration)
+BOOST_AUTO_TEST_CASE(SerializeDeserializeDepthwiseConvolution2d)
+{
+    armnn::TensorInfo inputInfo ({ 1, 5, 5, 3 }, armnn::DataType::Float32);
+    armnn::TensorInfo outputInfo({ 1, 3, 3, 3 }, armnn::DataType::Float32);
+
+    armnn::TensorInfo weightsInfo({ 1, 3, 3, 3 }, armnn::DataType::Float32);
+    armnn::TensorInfo biasesInfo ({ 3 }, armnn::DataType::Float32);
+
+    armnn::DepthwiseConvolution2dDescriptor descriptor;
+    descriptor.m_StrideX     = 1;
+    descriptor.m_StrideY     = 1;
+    descriptor.m_BiasEnabled = true;
+    descriptor.m_DataLayout  = armnn::DataLayout::NHWC;
+
+    std::vector<float> weightsData = GenerateRandomData<float>(weightsInfo.GetNumElements());
+    armnn::ConstTensor weights(weightsInfo, weightsData);
+
+    std::vector<float> biasesData = GenerateRandomData<float>(biasesInfo.GetNumElements());
+    armnn::ConstTensor biases(biasesInfo, biasesData);
+
+    armnn::INetworkPtr network = armnn::INetwork::Create();
+    armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+    armnn::IConnectableLayer* const depthwiseConvLayer =
+        network->AddDepthwiseConvolution2dLayer(descriptor, weights, biases, "depthwiseConv");
+    armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+
+    inputLayer->GetOutputSlot(0).Connect(depthwiseConvLayer->GetInputSlot(0));
+    inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
+    depthwiseConvLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+    depthwiseConvLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+    BOOST_CHECK(deserializedNetwork);
+
+    CheckDeserializedNetworkAgainstOriginal(*network,
+                                            *deserializedNetwork,
+                                            inputInfo.GetShape(),
+                                            outputInfo.GetShape());
+}
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeSoftmax)
 {
     armnn::TensorInfo tensorInfo({1, 10}, armnn::DataType::Float32);
 
     armnn::SoftmaxDescriptor descriptor;
     descriptor.m_Beta = 1.0f;
 
-    // Create test network
     armnn::INetworkPtr network = armnn::INetwork::Create();
     armnn::IConnectableLayer* const inputLayer   = network->AddInputLayer(0);
     armnn::IConnectableLayer* const softmaxLayer = network->AddSoftmaxLayer(descriptor, "softmax");
@@ -251,71 +314,22 @@
     softmaxLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
     softmaxLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
 
-    // Serialize & deserialize network
     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
     BOOST_CHECK(deserializedNetwork);
 
-    armnn::IRuntime::CreationOptions options;
-    armnn::IRuntimePtr runtime = armnn::IRuntime::Create(options);
-
-    armnn::IOptimizedNetworkPtr optimizedNetwork =
-        armnn::Optimize(*network, {armnn::Compute::CpuRef}, runtime->GetDeviceSpec());
-    BOOST_CHECK(optimizedNetwork);
-
-    armnn::IOptimizedNetworkPtr deserializedOptimizedNetwork =
-        armnn::Optimize(*deserializedNetwork, {armnn::Compute::CpuRef}, runtime->GetDeviceSpec());
-    BOOST_CHECK(deserializedOptimizedNetwork);
-
-    armnn::NetworkId networkId1;
-    armnn::NetworkId networkId2;
-
-    runtime->LoadNetwork(networkId1, std::move(optimizedNetwork));
-    runtime->LoadNetwork(networkId2, std::move(deserializedOptimizedNetwork));
-
-    std::vector<float> inputData(tensorInfo.GetNumElements());
-    std::iota(inputData.begin(), inputData.end(), 0);
-
-    armnn::InputTensors inputTensors1
-    {
-         {0, armnn::ConstTensor(runtime->GetInputTensorInfo(networkId1, 0), inputData.data())}
-    };
-
-    armnn::InputTensors inputTensors2
-    {
-         {0, armnn::ConstTensor(runtime->GetInputTensorInfo(networkId2, 0), inputData.data())}
-    };
-
-    std::vector<float> outputData1(inputData.size());
-    std::vector<float> outputData2(inputData.size());
-
-    armnn::OutputTensors outputTensors1
-    {
-         {0, armnn::Tensor(runtime->GetOutputTensorInfo(networkId1, 0), outputData1.data())}
-    };
-
-    armnn::OutputTensors outputTensors2
-    {
-         {0, armnn::Tensor(runtime->GetOutputTensorInfo(networkId2, 0), outputData2.data())}
-    };
-
-    runtime->EnqueueWorkload(networkId1, inputTensors1, outputTensors1);
-    runtime->EnqueueWorkload(networkId2, inputTensors2, outputTensors2);
-
-    BOOST_CHECK_EQUAL_COLLECTIONS(outputData1.begin(), outputData1.end(),
-                                  outputData2.begin(), outputData2.end());
+    CheckDeserializedNetworkAgainstOriginal(*network,
+                                            *deserializedNetwork,
+                                            tensorInfo.GetShape(),
+                                            tensorInfo.GetShape());
 }
 
-BOOST_AUTO_TEST_CASE(SimplePooling2dIntegration)
+BOOST_AUTO_TEST_CASE(SerializeDeserializePooling2d)
 {
-    armnn::NetworkId networkIdentifier;
-    armnn::IRuntime::CreationOptions options; // default options
-    armnn::IRuntimePtr runtime = armnn::IRuntime::Create(options);
-
     unsigned int inputShape[]  = {1, 2, 2, 1};
     unsigned int outputShape[] = {1, 1, 1, 1};
 
-    auto inputTensorInfo  = armnn::TensorInfo(4, inputShape, armnn::DataType::Float32);
-    auto outputTensorInfo = armnn::TensorInfo(4, outputShape, armnn::DataType::Float32);
+    auto inputInfo  = armnn::TensorInfo(4, inputShape, armnn::DataType::Float32);
+    auto outputInfo = armnn::TensorInfo(4, outputShape, armnn::DataType::Float32);
 
     armnn::Pooling2dDescriptor desc;
     desc.m_DataLayout          = armnn::DataLayout::NHWC;
@@ -337,36 +351,17 @@
     armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0);
 
     inputLayer->GetOutputSlot(0).Connect(pooling2dLayer->GetInputSlot(0));
-    inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+    inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
     pooling2dLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
-    pooling2dLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+    pooling2dLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
 
-    auto deserializeNetwork = DeserializeNetwork(SerializeNetwork(*network));
+    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+    BOOST_CHECK(deserializedNetwork);
 
-    //Optimize the deserialized network
-    auto deserializedOptimized = Optimize(*deserializeNetwork, {armnn::Compute::CpuRef},
-                                          runtime->GetDeviceSpec());
-
-    // Load graph into runtime
-    runtime->LoadNetwork(networkIdentifier, std::move(deserializedOptimized));
-
-    std::vector<float> input1Data(inputTensorInfo.GetNumElements());
-    std::iota(input1Data.begin(), input1Data.end(), 4);
-
-    armnn::InputTensors inputTensors
-    {
-          {0, armnn::ConstTensor(runtime->GetInputTensorInfo(networkIdentifier, 0), input1Data.data())}
-    };
-
-    std::vector<float> outputData(input1Data.size());
-    armnn::OutputTensors outputTensors
-    {
-           {0, armnn::Tensor(runtime->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())}
-    };
-
-    runtime->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
-
-    BOOST_CHECK_EQUAL(outputData[0], 5.5);
+    CheckDeserializedNetworkAgainstOriginal(*network,
+                                            *deserializedNetwork,
+                                            inputInfo.GetShape(),
+                                            outputInfo.GetShape());
 }
 
-BOOST_AUTO_TEST_SUITE_END()
+BOOST_AUTO_TEST_SUITE_END()
\ No newline at end of file