IVGCVSW-2691 Add Serialize/Deserialize Gather layer

Change-Id: I589d37c9f65801b701858d6e68e2e3151fac6e16
Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com>
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 185bdad..1c934a8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -605,6 +605,7 @@
                 src/armnnDeserializer/test/DeserializeEqual.cpp
                 src/armnnDeserializer/test/DeserializeFloor.cpp
                 src/armnnDeserializer/test/DeserializeFullyConnected.cpp
+                src/armnnDeserializer/test/DeserializeGather.cpp
                 src/armnnDeserializer/test/DeserializeGreater.cpp
                 src/armnnDeserializer/test/DeserializeMultiplication.cpp
                 src/armnnDeserializer/test/DeserializeNormalization.cpp
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index e4c9bd6..405d95e 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -196,6 +196,7 @@
     m_ParserFunctions[Layer_EqualLayer]                  = &Deserializer::ParseEqual;
     m_ParserFunctions[Layer_FullyConnectedLayer]         = &Deserializer::ParseFullyConnected;
     m_ParserFunctions[Layer_FloorLayer]                  = &Deserializer::ParseFloor;
+    m_ParserFunctions[Layer_GatherLayer]                 = &Deserializer::ParseGather;
     m_ParserFunctions[Layer_GreaterLayer]                = &Deserializer::ParseGreater;
     m_ParserFunctions[Layer_MinimumLayer]                = &Deserializer::ParseMinimum;
     m_ParserFunctions[Layer_MaximumLayer]                = &Deserializer::ParseMaximum;
@@ -241,6 +242,8 @@
             return graphPtr->layers()->Get(layerIndex)->layer_as_FullyConnectedLayer()->base();
         case Layer::Layer_FloorLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_FloorLayer()->base();
+        case Layer::Layer_GatherLayer:
+            return graphPtr->layers()->Get(layerIndex)->layer_as_GatherLayer()->base();
         case Layer::Layer_GreaterLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_GreaterLayer()->base();
         case Layer::Layer_InputLayer:
@@ -468,7 +471,7 @@
 }

 Deserializer::TensorRawPtrVector Deserializer::GetInputs(const GraphPtr& graphPtr,
-                                                                   unsigned int layerIndex)
+                                                         unsigned int layerIndex)
 {
     CHECK_LAYERS(graphPtr, 0, layerIndex);
     auto layer = GetBaseLayer(graphPtr, layerIndex);
@@ -611,7 +614,7 @@
 }

 BindingPointInfo Deserializer::GetNetworkInputBindingInfo(unsigned int layerIndex,
-                                                               const std::string& name) const
+                                                          const std::string& name) const
 {
     for (auto inputBinding : m_InputBindings)
     {
@@ -1710,4 +1713,27 @@
     RegisterOutputSlots(graph, layerIndex, layer);
 }

+// Expects two inputs (params, indices) and one output; the output TensorInfo comes from the serialized graph.
+void Deserializer::ParseGather(GraphPtr graph, unsigned int layerIndex)
+{
+    CHECK_LAYERS(graph, 0, layerIndex);
+
+    Deserializer::TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
+    CHECK_VALID_SIZE(inputs.size(), 2);
+
+    Deserializer::TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
+    CHECK_VALID_SIZE(outputs.size(), 1);
+
+    auto layerName = GetLayerName(graph, layerIndex);
+
+    IConnectableLayer* layer = m_Network->AddGatherLayer(layerName.c_str());
+
+    armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+
+    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+    RegisterInputSlots(graph, layerIndex, layer);
+    RegisterOutputSlots(graph, layerIndex, layer);
+}
+
 } // namespace armnnDeserializer
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index b45551f..0580025 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -82,6 +82,7 @@
     void ParseEqual(GraphPtr graph, unsigned int layerIndex);
     void ParseFloor(GraphPtr graph, unsigned int layerIndex);
     void ParseFullyConnected(GraphPtr graph, unsigned int layerIndex);
+    void ParseGather(GraphPtr graph, unsigned int layerIndex);
     void ParseGreater(GraphPtr graph, unsigned int layerIndex);
     void ParseMinimum(GraphPtr graph, unsigned int layerIndex);
     void ParseMaximum(GraphPtr graph, unsigned int layerIndex);
diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md
index a295676..cc7c626 100644
--- a/src/armnnDeserializer/DeserializerSupport.md
+++ b/src/armnnDeserializer/DeserializerSupport.md
@@ -17,6 +17,7 @@
 * Equal
 * Floor
 * FullyConnected
+* Gather
 * Greater
 * Maximum
 * Minimum
diff --git a/src/armnnDeserializer/test/DeserializeGather.cpp b/src/armnnDeserializer/test/DeserializeGather.cpp
new file mode 100644
index 0000000..3fdcf51
--- /dev/null
+++ b/src/armnnDeserializer/test/DeserializeGather.cpp
@@ -0,0 +1,159 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <boost/test/unit_test.hpp>
+#include "ParserFlatbuffersSerializeFixture.hpp"
+#include "../Deserializer.hpp"
+
+#include <string>
+#include <iostream>
+
+BOOST_AUTO_TEST_SUITE(Deserializer)
+
+// Serialized graph under test: Input (params, slot 0) and Constant (indices, slot 1) feed a
+// Gather layer whose result goes to Output. The indices tensor is always Signed32.
+struct GatherFixture : public ParserFlatbuffersSerializeFixture
+{
+    explicit GatherFixture(const std::string &inputShape,
+                           const std::string &indicesShape,
+                           const std::string &input1Content,
+                           const std::string &outputShape,
+                           const std::string &dataType,
+                           const std::string &constDataType)
+    {
+        m_JsonString = R"(
+        {
+                inputIds: [0],
+                outputIds: [3],
+                layers: [
+                {
+                    layer_type: "InputLayer",
+                    layer: {
+                          base: {
+                                layerBindingId: 0,
+                                base: {
+                                    index: 0,
+                                    layerName: "InputLayer",
+                                    layerType: "Input",
+                                    inputSlots: [{
+                                        index: 0,
+                                        connection: {sourceLayerIndex:0, outputSlotIndex:0 },
+                                    }],
+                                    outputSlots: [ {
+                                        index: 0,
+                                        tensorInfo: {
+                                            dimensions: )" + inputShape + R"(,
+                                            dataType: )" + dataType + R"(
+                                            }}]
+                                    }
+                    }}},
+                    {
+                    layer_type: "ConstantLayer",
+                        layer: {
+                               base: {
+                                  index:1,
+                                  layerName: "ConstantLayer",
+                                  layerType: "Constant",
+                                   outputSlots: [ {
+                                    index: 0,
+                                    tensorInfo: {
+                                        dimensions: )" + indicesShape + R"(,
+                                        dataType: "Signed32",
+                                    },
+                                  }],
+                              },
+                              input: {
+                              info: {
+                                       dimensions: )" + indicesShape + R"(,
+                                       dataType: "Signed32"
+                                   },
+                              data_type: )" + constDataType + R"(,
+                              data: {
+                                  data: )" + input1Content + R"(,
+                                    } }
+                                },},
+                    {
+                    layer_type: "GatherLayer",
+                        layer: {
+                              base: {
+                                   index: 2,
+                                   layerName: "GatherLayer",
+                                   layerType: "Gather",
+                                   inputSlots: [
+                                   {
+                                       index: 0,
+                                       connection: {sourceLayerIndex:0, outputSlotIndex:0 },
+                                   },
+                                   {
+                                        index: 1,
+                                        connection: {sourceLayerIndex:1, outputSlotIndex:0 }
+                                   }],
+                                   outputSlots: [ {
+                                          index: 0,
+                                          tensorInfo: {
+                                               dimensions: )" + outputShape + R"(,
+                                               dataType: )" + dataType + R"(
+
+                                   }}]}
+                        }},
+                    {
+                    layer_type: "OutputLayer",
+                    layer: {
+                        base:{
+                              layerBindingId: 0,
+                              base: {
+                                    index: 3,
+                                    layerName: "OutputLayer",
+                                    layerType: "Output",
+                                    inputSlots: [{
+                                        index: 0,
+                                        connection: {sourceLayerIndex:2, outputSlotIndex:0 },
+                                    }],
+                                    outputSlots: [ {
+                                        index: 0,
+                                        tensorInfo: {
+                                            dimensions: )" + outputShape + R"(,
+                                            dataType: )" + dataType + R"(
+                                        },
+                                }],
+                            }}},
+                }]
+                 } )";
+
+        Setup();
+    }
+};
+
+struct SimpleGatherFixtureFloat32 : GatherFixture
+{
+    SimpleGatherFixtureFloat32() : GatherFixture("[ 3, 2, 3 ]", "[ 2, 3 ]", "[1, 2, 1, 2, 1, 0]",
+                                                 "[ 2, 3, 2, 3 ]", "Float32", "IntData") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(GatherFloat32, SimpleGatherFixtureFloat32)
+{
+    RunTest<4, armnn::DataType::Float32>(0,
+                                         {{"InputLayer", {  1,  2,  3,
+                                                            4,  5,  6,
+                                                            7,  8,  9,
+                                                            10, 11, 12,
+                                                            13, 14, 15,
+                                                            16, 17, 18 }}},
+                                         {{"OutputLayer", { 7,  8,  9,
+                                                            10, 11, 12,
+                                                            13, 14, 15,
+                                                            16, 17, 18,
+                                                            7,  8,  9,
+                                                            10, 11, 12,
+                                                            13, 14, 15,
+                                                            16, 17, 18,
+                                                            7,  8,  9,
+                                                            10, 11, 12,
+                                                            1,  2,  3,
+                                                            4,  5,  6 }}});
+}
+
+BOOST_AUTO_TEST_SUITE_END()
+
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index ed3de83..ac32e66 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -108,7 +108,8 @@
     Greater = 24,
     ResizeBilinear = 25,
     Subtraction = 26,
-    StridedSlice = 27
+    StridedSlice = 27,
+    Gather = 28
 }

 // Base layer table to be used as part of other layers
@@ -188,6 +189,10 @@
     transposeWeightsMatrix:bool = false;
 }

+table GatherLayer {
+    base:LayerBase;
+}
+
 table GreaterLayer {
     base:LayerBase;
 }
@@ -427,7 +432,8 @@
     GreaterLayer,
     ResizeBilinearLayer,
     SubtractionLayer,
-    StridedSliceLayer
+    StridedSliceLayer,
+    GatherLayer
 }

 table AnyLayer {
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 9653908..38e815d 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -329,6 +329,16 @@
     CreateAnyLayer(fbMaximumLayer.o, serializer::Layer::Layer_MaximumLayer);
 }

+// Gather has no descriptor, so only the common LayerBase is written out.
+void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name)
+{
+    auto fbGatherBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Gather);
+
+    auto fbGatherLayer = CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer);
+
+    CreateAnyLayer(fbGatherLayer.o, serializer::Layer::Layer_GatherLayer);
+}
+
 void SerializerVisitor::VisitGreaterLayer(const armnn::IConnectableLayer* layer, const char* name)
 {
     auto fbGreaterBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Greater);
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index 0dcacc8..1b1a3e9 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -92,6 +92,9 @@
                                   const armnn::Optional<armnn::ConstTensor>& biases,
                                   const char* name = nullptr) override;

+    void VisitGatherLayer(const armnn::IConnectableLayer* layer,
+                          const char* name = nullptr) override;
+
     void VisitGreaterLayer(const armnn::IConnectableLayer* layer,
                            const char* name = nullptr) override;

diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md
index 713b35f..24a764a 100644
--- a/src/armnnSerializer/SerializerSupport.md
+++ b/src/armnnSerializer/SerializerSupport.md
@@ -17,6 +17,7 @@
 * Equal
 * Floor
 * FullyConnected
+* Gather
 * Greater
 * Maximum
 * Minimum
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index 5a6806a..2751aff 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -61,12 +61,15 @@
     return randomData;
 }

+// T: element type of the random input data and of the output buffers.
+// It is not deducible from the arguments, so callers must name it, e.g. <float>.
+template<typename T>
 void CheckDeserializedNetworkAgainstOriginal(const armnn::INetwork& deserializedNetwork,
                                              const armnn::INetwork& originalNetwork,
                                              const std::vector<armnn::TensorShape>& inputShapes,
                                              const std::vector<armnn::TensorShape>& outputShapes,
                                              const std::vector<armnn::LayerBindingId>& inputBindingIds = {0},
                                              const std::vector<armnn::LayerBindingId>& outputBindingIds = {0})
 {
     BOOST_CHECK(inputShapes.size() == inputBindingIds.size());
     BOOST_CHECK(outputShapes.size() == outputBindingIds.size());
@@ -99,12 +100,12 @@
     // Generate some input data
     armnn::InputTensors inputTensors1;
     armnn::InputTensors inputTensors2;
-    std::vector<std::vector<float>> inputData;
+    std::vector<std::vector<T>> inputData;
     inputData.reserve(inputShapes.size());

     for (unsigned int i = 0; i < inputShapes.size(); i++)
     {
-        inputData.push_back(GenerateRandomData<float>(inputShapes[i].GetNumElements()));
+        inputData.push_back(GenerateRandomData<T>(inputShapes[i].GetNumElements()));

         inputTensors1.emplace_back(
             i, armnn::ConstTensor(runtime->GetInputTensorInfo(networkId1, inputBindingIds[i]), inputData[i].data()));
@@ -115,8 +115,8 @@

     armnn::OutputTensors outputTensors1;
     armnn::OutputTensors outputTensors2;
-    std::vector<std::vector<float>> outputData1;
-    std::vector<std::vector<float>> outputData2;
+    std::vector<std::vector<T>> outputData1;
+    std::vector<std::vector<T>> outputData2;
     outputData1.reserve(outputShapes.size());
     outputData2.reserve(outputShapes.size());

@@ -249,10 +249,10 @@
     VerifyConstantName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*net,
-                                            *deserializedNetwork,
-                                            {commonTensorInfo.GetShape()},
-                                            {commonTensorInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *net,
+                                                   {commonTensorInfo.GetShape()},
+                                                   {commonTensorInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeFloor)
@@ -502,10 +502,10 @@
     VerifyConvolution2dName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputInfo.GetShape()},
-                                            {outputInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputInfo.GetShape()},
+                                                   {outputInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializeGreater)
@@ -542,11 +542,11 @@
     VerifyGreaterName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputTensorInfo1.GetShape(), inputTensorInfo2.GetShape()},
-                                            {outputTensorInfo.GetShape()},
-                                            {0, 1});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputTensorInfo1.GetShape(), inputTensorInfo2.GetShape()},
+                                                   {outputTensorInfo.GetShape()},
+                                                   {0, 1});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializeReshape)
@@ -586,10 +586,10 @@
     VerifyReshapeName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputTensorInfo.GetShape()},
-                                            {outputTensorInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputTensorInfo.GetShape()},
+                                                   {outputTensorInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializeDepthwiseConvolution2d)
@@ -642,10 +642,10 @@
     VerifyDepthwiseConvolution2dName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputInfo.GetShape()},
-                                            {outputInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputInfo.GetShape()},
+                                                   {outputInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializeSoftmax)
@@ -680,10 +680,10 @@
     VerifySoftmaxName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {tensorInfo.GetShape()},
-                                            {tensorInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {tensorInfo.GetShape()},
+                                                   {tensorInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializePooling2d)
@@ -732,10 +732,10 @@
     VerifyPooling2dName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputInfo.GetShape()},
-                                            {outputInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputInfo.GetShape()},
+                                                   {outputInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializePermute)
@@ -774,10 +774,10 @@
     VerifyPermuteName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputTensorInfo.GetShape()},
-                                            {outputTensorInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputTensorInfo.GetShape()},
+                                                   {outputTensorInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializeFullyConnected)
@@ -831,10 +831,10 @@
     VerifyFullyConnectedName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputInfo.GetShape()},
-                                            {outputInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputInfo.GetShape()},
+                                                   {outputInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializeSpaceToBatchNd)
@@ -877,56 +877,103 @@
     VerifySpaceToBatchNdName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputTensorInfo.GetShape()},
-                                            {outputTensorInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputTensorInfo.GetShape()},
+                                                   {outputTensorInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializeBatchToSpaceNd)
 {
     class VerifyBatchToSpaceNdName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
     {
     public:
         void VisitBatchToSpaceNdLayer(const armnn::IConnectableLayer*,
                                       const armnn::BatchToSpaceNdDescriptor& descriptor,
                                       const char* name) override
         {
             BOOST_TEST(name == "BatchToSpaceNdLayer");
         }
     };

     unsigned int inputShape[] = {4, 1, 2, 2};
     unsigned int outputShape[] = {1, 1, 4, 4};

     armnn::BatchToSpaceNdDescriptor desc;
     desc.m_DataLayout = armnn::DataLayout::NCHW;
     desc.m_BlockShape = {2, 2};
     desc.m_Crops = {{0, 0}, {0, 0}};

     auto inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::DataType::Float32);
     auto outputTensorInfo = armnn::TensorInfo(4, outputShape, armnn::DataType::Float32);

     armnn::INetworkPtr network = armnn::INetwork::Create();
     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
     armnn::IConnectableLayer* const batchToSpaceNdLayer = network->AddBatchToSpaceNdLayer(desc, "BatchToSpaceNdLayer");
     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);

     inputLayer->GetOutputSlot(0).Connect(batchToSpaceNdLayer->GetInputSlot(0));
     inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
     batchToSpaceNdLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
     batchToSpaceNdLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);

     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
     BOOST_CHECK(deserializedNetwork);

     VerifyBatchToSpaceNdName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputTensorInfo.GetShape()},
-                                            {outputTensorInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputTensorInfo.GetShape()},
+                                                   {outputTensorInfo.GetShape()});
+}
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeGather)
+{
+    class VerifyGatherName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
+    {
+    public:
+        void VisitGatherLayer(const armnn::IConnectableLayer*, const char* name) override
+        {
+            BOOST_TEST(name == "gatherLayer");
+        }
+    };
+
+    armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::QuantisedAsymm8);
+    armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
+    armnn::TensorInfo outputInfo({ 3 }, armnn::DataType::QuantisedAsymm8);
+
+    paramsInfo.SetQuantizationScale(1.0f);
+    paramsInfo.SetQuantizationOffset(0);
+    outputInfo.SetQuantizationScale(1.0f);
+    outputInfo.SetQuantizationOffset(0);
+
+    const std::vector<int32_t> indicesData = {7, 6, 5};
+    armnn::INetworkPtr network = armnn::INetwork::Create();
+    armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+    armnn::IConnectableLayer* const constantLayer =
+            network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
+    armnn::IConnectableLayer* const gatherLayer = network->AddGatherLayer("gatherLayer");
+    armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+
+    inputLayer->GetOutputSlot(0).Connect(gatherLayer->GetInputSlot(0));
+    inputLayer->GetOutputSlot(0).SetTensorInfo(paramsInfo);
+    constantLayer->GetOutputSlot(0).Connect(gatherLayer->GetInputSlot(1));
+    constantLayer->GetOutputSlot(0).SetTensorInfo(indicesInfo);
+    gatherLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+    gatherLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+    BOOST_CHECK(deserializedNetwork);
+
+    VerifyGatherName nameChecker;
+    deserializedNetwork->Accept(nameChecker);
+
+    CheckDeserializedNetworkAgainstOriginal<uint8_t>(*deserializedNetwork,
+                                                     *network,
+                                                     {paramsInfo.GetShape()},
+                                                     {outputInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializeBatchNormalization)
@@ -991,10 +992,10 @@
     VerifyBatchNormalizationName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputInfo.GetShape()},
-                                            {outputInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputInfo.GetShape()},
+                                                   {outputInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeDivision)
@@ -1084,10 +1085,10 @@
     VerifyNormalizationName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputTensorInfo.GetShape()},
-                                            {outputTensorInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputTensorInfo.GetShape()},
+                                                   {outputTensorInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializeEqual)
@@ -1125,11 +1126,11 @@
     VerifyEqualName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputTensorInfo1.GetShape(), inputTensorInfo2.GetShape()},
-                                            {outputTensorInfo.GetShape()},
-                                            {0, 1});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputTensorInfo1.GetShape(), inputTensorInfo2.GetShape()},
+                                                   {outputTensorInfo.GetShape()},
+                                                   {0, 1});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializePad)
@@ -1166,10 +1167,10 @@
     VerifyPadName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputTensorInfo.GetShape()},
-                                            {outputTensorInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputTensorInfo.GetShape()},
+                                                   {outputTensorInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeRsqrt)
@@ -1202,10 +1203,10 @@
     VerifyRsqrtName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {tensorInfo.GetShape()},
-                                            {tensorInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {tensorInfo.GetShape()},
+                                                   {tensorInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializeResizeBilinear)
@@ -1244,10 +1245,10 @@
     VerifyResizeBilinearName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputTensorInfo.GetShape()},
-                                            {outputTensorInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputTensorInfo.GetShape()},
+                                                   {outputTensorInfo.GetShape()});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializeSubtraction)
@@ -1283,11 +1284,11 @@
     VerifySubtractionName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {info.GetShape(), info.GetShape()},
-                                            {info.GetShape()},
-                                            {0, 1});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {info.GetShape(), info.GetShape()},
+                                                   {info.GetShape()},
+                                                   {0, 1});
 }

 BOOST_AUTO_TEST_CASE(SerializeDeserializeStridedSlice)
@@ -1327,10 +1328,10 @@
     VerifyStridedSliceName nameChecker;
     deserializedNetwork->Accept(nameChecker);

-    CheckDeserializedNetworkAgainstOriginal(*network,
-                                            *deserializedNetwork,
-                                            {inputTensorInfo.GetShape()},
-                                            {outputTensorInfo.GetShape()});
+    CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork,
+                                                   *network,
+                                                   {inputTensorInfo.GetShape()},
+                                                   {outputTensorInfo.GetShape()});
 }

 BOOST_AUTO_TEST_SUITE_END()