IVGCVSW-8230 Add ScatterNd to Serializer and Deserializer

* Added parsing functions to the serializer and deserializer
* Added ScatterNd and its Descriptor to the ArmnnSchema.fbs
* Added Unittest for Serializer and Deserializer

Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: I1ed674dc32d2e2d0d84dca4c7018984ea367ea50
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 966a273..6f6a84d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -842,6 +842,7 @@
             src/armnnDeserializer/test/DeserializeReverseV2.cpp
             src/armnnDeserializer/test/DeserializeRsqrt.cpp
             src/armnnDeserializer/test/DeserializeShape.cpp
+            src/armnnDeserializer/test/DeserializeScatterNd.cpp
             src/armnnDeserializer/test/DeserializeSlice.cpp
             src/armnnDeserializer/test/DeserializeSpaceToBatchNd.cpp
             src/armnnDeserializer/test/DeserializeStridedSlice.cpp
diff --git a/docs/05_02_serializer.dox b/docs/05_02_serializer.dox
index a319178..1bb3d53 100644
--- a/docs/05_02_serializer.dox
+++ b/docs/05_02_serializer.dox
@@ -116,6 +116,8 @@
 
 - ReverseV2
 
+- ScatterNd
+
 - Shape
 
 - Slice
diff --git a/docs/05_06_deserializer.dox b/docs/05_06_deserializer.dox
index bf44dbf..6f5dcf4 100644
--- a/docs/05_06_deserializer.dox
+++ b/docs/05_06_deserializer.dox
@@ -118,6 +118,8 @@
 
 - ReverseV2
 
+- ScatterNd
+
 - Slice
 
 - Softmax
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index f27489f..b77eb0e 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2019-2024 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -265,6 +265,7 @@
     m_ParserFunctions[Layer_ResizeLayer]                 = &DeserializerImpl::ParseResize;
     m_ParserFunctions[Layer_ReverseV2Layer]              = &DeserializerImpl::ParseReverseV2;
     m_ParserFunctions[Layer_RsqrtLayer]                  = &DeserializerImpl::ParseRsqrt;
+    m_ParserFunctions[Layer_ScatterNdLayer]              = &DeserializerImpl::ParseScatterNd;
     m_ParserFunctions[Layer_ShapeLayer]                  = &DeserializerImpl::ParseShape;
     m_ParserFunctions[Layer_SliceLayer]                  = &DeserializerImpl::ParseSlice;
     m_ParserFunctions[Layer_SoftmaxLayer]                = &DeserializerImpl::ParseSoftmax;
@@ -402,6 +403,8 @@
             return graphPtr->layers()->Get(layerIndex)->layer_as_ReverseV2Layer()->base();
         case Layer::Layer_RsqrtLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_RsqrtLayer()->base();
+        case Layer::Layer_ScatterNdLayer:
+            return graphPtr->layers()->Get(layerIndex)->layer_as_ScatterNdLayer()->base();
         case Layer::Layer_ShapeLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_ShapeLayer()->base();
         case Layer::Layer_SliceLayer:
@@ -521,6 +524,25 @@
     }
 }
 
+armnn::ScatterNdFunction ToScatterNdFunction(armnnSerializer::ScatterNdFunction function)
+{
+    switch (function)
+    {
+        case armnnSerializer::ScatterNdFunction_Update:
+            return armnn::ScatterNdFunction::Update;
+        case armnnSerializer::ScatterNdFunction_Add:
+            return armnn::ScatterNdFunction::Add;
+        case armnnSerializer::ScatterNdFunction_Sub:
+            return armnn::ScatterNdFunction::Sub;
+        case armnnSerializer::ScatterNdFunction_Max:
+            return armnn::ScatterNdFunction::Max;
+        case armnnSerializer::ScatterNdFunction_Min:
+            return armnn::ScatterNdFunction::Min;
+        default:
+            return armnn::ScatterNdFunction::Update;
+    }
+}
+
 armnn::ComparisonOperation ToComparisonOperation(armnnSerializer::ComparisonOperation operation)
 {
     switch (operation)
@@ -4008,4 +4030,33 @@
     RegisterOutputSlots(graph, layerIndex, layer);
 }
 
+void IDeserializer::DeserializerImpl::ParseScatterNd(GraphPtr graph, unsigned int layerIndex)
+{
+    CHECK_LAYERS(graph, 0, layerIndex);
+    auto inputs = GetInputs(graph, layerIndex);
+    CHECK_LOCATION();
+    CHECK_VALID_SIZE(inputs.size(), 3);
+
+    auto outputs = GetOutputs(graph, layerIndex);
+    CHECK_VALID_SIZE(outputs.size(), 1);
+
+    auto ScatterNdLayer        = graph->layers()->Get(layerIndex)->layer_as_ScatterNdLayer();
+    auto layerName             = GetLayerName(graph, layerIndex);
+    auto flatBufferDescriptor  = ScatterNdLayer->descriptor();
+
+    armnn::ScatterNdDescriptor scatterNdDescriptor;
+    scatterNdDescriptor.m_Function     = ToScatterNdFunction(flatBufferDescriptor->m_Function());
+    scatterNdDescriptor.m_InputEnabled = flatBufferDescriptor->m_InputEnabled();
+    scatterNdDescriptor.m_Axis         = flatBufferDescriptor->m_Axis();
+    scatterNdDescriptor.m_AxisEnabled  = flatBufferDescriptor->m_AxisEnabled();
+
+    IConnectableLayer* layer = m_Network->AddScatterNdLayer(scatterNdDescriptor, layerName.c_str());
+
+    armnn::TensorInfo output0TensorInfo = ToTensorInfo(outputs[0]);
+    layer->GetOutputSlot(0).SetTensorInfo(output0TensorInfo);
+
+    RegisterInputSlots(graph, layerIndex, layer);
+    RegisterOutputSlots(graph, layerIndex, layer);
+}
+
 } // namespace armnnDeserializer
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index 7e427d6..4b29a70 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2019-2024 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -138,6 +138,7 @@
     void ParseResizeBilinear(GraphPtr graph, unsigned int layerIndex);
     void ParseReverseV2(GraphPtr graph, unsigned int layerIndex);
     void ParseRsqrt(GraphPtr graph, unsigned int layerIndex);
+    void ParseScatterNd(GraphPtr graph, unsigned int layerIndex);
     void ParseShape(GraphPtr graph, unsigned int layerIndex);
     void ParseSlice(GraphPtr graph, unsigned int layerIndex);
     void ParseSoftmax(GraphPtr graph, unsigned int layerIndex);
diff --git a/src/armnnDeserializer/test/DeserializeScatterNd.cpp b/src/armnnDeserializer/test/DeserializeScatterNd.cpp
new file mode 100644
index 0000000..3e88310
--- /dev/null
+++ b/src/armnnDeserializer/test/DeserializeScatterNd.cpp
@@ -0,0 +1,179 @@
+//
+// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ParserFlatbuffersSerializeFixture.hpp"
+#include <armnnDeserializer/IDeserializer.hpp>
+
+#include <string>
+
+TEST_SUITE("Deserializer_ScatterNd")
+{
+struct ScatterNdFixture : public ParserFlatbuffersSerializeFixture
+{
+    explicit ScatterNdFixture(const std::string& inputShape,
+                              const std::string& indicesShape,
+                              const std::string& updatesShape,
+                              const std::string& outputShape,
+                              const std::string& indicesData,
+                              const std::string& updatesData,
+                              const std::string dataType,
+                              const std::string constDataType)
+    {
+        m_JsonString = R"(
+        {
+                inputIds: [0],
+                outputIds: [4],
+                layers: [
+                {
+                    layer_type: "InputLayer",
+                    layer: {
+                          base: {
+                                layerBindingId: 0,
+                                base: {
+                                    index: 0,
+                                    layerName: "InputLayer",
+                                    layerType: "Input",
+                                    inputSlots: [{
+                                        index: 0,
+                                        connection: {sourceLayerIndex:0, outputSlotIndex:0 },
+                                    }],
+                                    outputSlots: [ {
+                                        index: 0,
+                                        tensorInfo: {
+                                            dimensions: )" + inputShape + R"(,
+                                            dataType: )" + dataType + R"(
+                                            }}]
+                                    }
+                    }}},
+                    {
+                    layer_type: "ConstantLayer",
+                        layer: {
+                               base: {
+                                  index:1,
+                                  layerName: "ConstantLayer",
+                                  layerType: "Constant",
+                                   outputSlots: [ {
+                                    index: 0,
+                                    tensorInfo: {
+                                        dimensions: )" + indicesShape + R"(,
+                                        dataType: "Signed32",
+                                    },
+                                  }],
+                              },
+                              input: {
+                              info: {
+                                       dimensions: )" + indicesShape + R"(,
+                                       dataType: "Signed32",
+                                   },
+                              data_type: )" + constDataType + R"(,
+                              data: {
+                                  data: )" + indicesData + R"(,
+                                    } }
+                                },},
+                    {
+                    layer_type: "ConstantLayer",
+                        layer: {
+                               base: {
+                                  index:2,
+                                  layerName: "ConstantLayer",
+                                  layerType: "Constant",
+                                   outputSlots: [ {
+                                    index: 0,
+                                    tensorInfo: {
+                                        dimensions: )" + updatesShape + R"(,
+                                        dataType: )" + dataType + R"(
+                                    },
+                                  }],
+                              },
+                              input: {
+                              info: {
+                                       dimensions: )" + updatesShape + R"(,
+                                       dataType: )" + dataType + R"(
+                                   },
+                              data_type: )" + constDataType + R"(,
+                              data: {
+                                  data: )" + updatesData + R"(,
+                                    } }
+                                },},
+                    {
+                    layer_type: "ScatterNdLayer",
+                        layer: {
+                              base: {
+                                   index: 3,
+                                   layerName: "ScatterNdLayer",
+                                   layerType: "ScatterNd",
+                                   inputSlots: [
+                                   {
+                                       index: 0,
+                                       connection: {sourceLayerIndex:0, outputSlotIndex:0 },
+                                   },
+                                   {
+                                       index: 1,
+                                       connection: {sourceLayerIndex:1, outputSlotIndex:0 },
+                                   },
+                                   {
+                                       index: 2,
+                                       connection: {sourceLayerIndex:2, outputSlotIndex:0 },
+                                   }],
+                                   outputSlots: [ {
+                                          index: 0,
+                                          tensorInfo: {
+                                               dimensions: )" + outputShape + R"(,
+                                               dataType: )" + dataType + R"(
+
+                                   }}]},
+                                    descriptor: {
+                                        m_Function: Update,
+                                        m_InputEnabled: true,
+                                        m_Axis: 0,
+                                        m_AxisEnabled: false
+                                        },
+                        }},
+                    {
+                    layer_type: "OutputLayer",
+                    layer: {
+                        base:{
+                              layerBindingId: 0,
+                              base: {
+                                    index: 4,
+                                    layerName: "OutputLayer",
+                                    layerType: "Output",
+                                    inputSlots: [{
+                                        index: 0,
+                                        connection: {sourceLayerIndex:3, outputSlotIndex:0 },
+                                    }],
+                                    outputSlots: [ {
+                                        index: 0,
+                                        tensorInfo: {
+                                            dimensions: )" + outputShape + R"(,
+                                            dataType: )" + dataType + R"(
+                                        },
+                                }],
+                            }}},
+                }],
+                featureVersions: {
+                    weightsLayoutScheme: 1,
+                }
+                 } )";
+
+        Setup();
+    }
+};
+
+struct SimpleScatterNdFixtureSigned32 : ScatterNdFixture
+{
+    SimpleScatterNdFixtureSigned32() : ScatterNdFixture("[ 5 ]", "[ 3, 1 ]", "[ 3 ]", "[ 5 ]",
+                                                       "[ 0, 1, 2 ]", "[ 1, 2, 3 ]", "Signed32", "IntData") {}
+};
+
+TEST_CASE_FIXTURE(SimpleScatterNdFixtureSigned32, "ScatterNdSigned32")
+{
+    RunTest<1, armnn::DataType::Signed32>(0,
+                                         {{"InputLayer", {  0, 0, 0, 0, 0 }}},
+                                         {{"OutputLayer", { 1, 2, 3, 0, 0 }}});
+}
+
+}
+
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index 131970e..3a01c50 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2019-2024 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -64,6 +64,14 @@
     Bilinear = 1,
 }
 
+enum ScatterNdFunction: byte {
+    Update = 0,
+    Add    = 1,
+    Sub    = 2,
+    Max    = 3,
+    Min    = 4
+}
+
 table TensorInfo {
     dimensions:[uint];
     dataType:DataType;
@@ -189,6 +197,7 @@
     ElementwiseBinary = 69,
     ReverseV2 = 70,
     Tile = 71,
+    ScatterNd = 72,
 }
 
 // Base layer table to be used as part of other layers
@@ -1066,6 +1075,18 @@
     descriptor:TileDescriptor;
 }
 
+table ScatterNdDescriptor {
+   m_Function:ScatterNdFunction = Update;
+   m_InputEnabled:bool = true;
+   m_Axis:int = 0;
+   m_AxisEnabled:bool = false;
+}
+
+table ScatterNdLayer {
+    base:LayerBase;
+    descriptor:ScatterNdDescriptor;
+}
+
 union Layer {
     ActivationLayer,
     AdditionLayer,
@@ -1139,6 +1160,7 @@
     ElementwiseBinaryLayer,
     ReverseV2Layer,
     TileLayer,
+    ScatterNdLayer,
 }
 
 table AnyLayer {
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index ef2ca48..ffdac43 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2019-2024 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 #include "Serializer.hpp"
@@ -97,6 +97,25 @@
     }
 }
 
+serializer::ScatterNdFunction GetFlatBufferScatterNdFunction(armnn::ScatterNdFunction function)
+{
+    switch (function)
+    {
+        case armnn::ScatterNdFunction::Update:
+            return serializer::ScatterNdFunction::ScatterNdFunction_Update;
+        case armnn::ScatterNdFunction::Add:
+            return serializer::ScatterNdFunction::ScatterNdFunction_Add;
+        case armnn::ScatterNdFunction::Sub:
+            return serializer::ScatterNdFunction::ScatterNdFunction_Sub;
+        case armnn::ScatterNdFunction::Max:
+            return serializer::ScatterNdFunction::ScatterNdFunction_Max;
+        case armnn::ScatterNdFunction::Min:
+            return serializer::ScatterNdFunction::ScatterNdFunction_Min;
+        default:
+            return serializer::ScatterNdFunction::ScatterNdFunction_Update;
+    }
+}
+
 uint32_t SerializerStrategy::GetSerializedId(LayerGuid guid)
 {
     if (m_guidMap.empty())
@@ -1347,6 +1366,32 @@
     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_NormalizationLayer);
 }
 
+void SerializerStrategy::SerializeScatterNdLayer(const armnn::IConnectableLayer* layer,
+                                                 const armnn::ScatterNdDescriptor& descriptor,
+                                                 const char* name)
+{
+    IgnoreUnused(name);
+
+    // Create FlatBuffer BaseLayer
+    auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ScatterNd);
+
+    auto flatBufferDesc = serializer::CreateScatterNdDescriptor(
+            m_flatBufferBuilder,
+            GetFlatBufferScatterNdFunction(descriptor.m_Function),
+            descriptor.m_InputEnabled,
+            descriptor.m_Axis,
+            descriptor.m_AxisEnabled);
+
+    // Create the FlatBuffer TileLayer
+    auto flatBufferLayer = serializer::CreateScatterNdLayer(
+            m_flatBufferBuilder,
+            flatBufferBaseLayer,
+            flatBufferDesc);
+
+    // Add the AnyLayer to the FlatBufferLayers
+    CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ScatterNdLayer);
+}
+
 void SerializerStrategy::SerializeShapeLayer(const armnn::IConnectableLayer* layer,
                                              const char* name)
 {
@@ -2379,6 +2424,13 @@
             SerializeReverseV2Layer(layer, name);
             break;
         }
+        case armnn::LayerType::ScatterNd:
+        {
+            const armnn::ScatterNdDescriptor& layerDescriptor =
+                    static_cast<const armnn::ScatterNdDescriptor&>(descriptor);
+            SerializeScatterNdLayer(layer, layerDescriptor, name);
+            break;
+        }
         case armnn::LayerType::Shape:
         {
             SerializeShapeLayer(layer, name);
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index afff66e..7434d63 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2019-2024 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 #pragma once
@@ -312,6 +312,10 @@
                                      const armnn::NormalizationDescriptor& normalizationDescriptor,
                                      const char* name = nullptr);
 
+    void SerializeScatterNdLayer(const armnn::IConnectableLayer* layer,
+                                 const armnn::ScatterNdDescriptor& descriptor,
+                                 const char* name);
+
     void SerializeShapeLayer(const armnn::IConnectableLayer* layer,
                              const char* name = nullptr);
 
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index bfe3fc6..37acb0c 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017,2020-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2020-2024 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -3065,4 +3065,46 @@
     deserializedNetwork->ExecuteStrategy(verifier);
 }
 
+TEST_CASE("SerializeScatterNd")
+{
+    const std::string layerName("ScatterNd");
+    const armnn::TensorInfo inputInfo ({ 5 }, armnn::DataType::Float32);
+    const armnn::TensorInfo outputInfo ({ 5 }, armnn::DataType::Float32);
+    const armnn::TensorInfo indicesInfo({ 3, 1 }, armnn::DataType::Float32, 0.0f, 0, true);
+    const armnn::TensorInfo updatesInfo ({ 3 }, armnn::DataType::Float32,0.0f, 0, true);
+    std::vector<float> indicesData = { 0, 2, 3 };
+    const armnn::ConstTensor indices(indicesInfo, indicesData);
+
+    std::vector<float> updatesData = { 4, 5, 6 };
+    const armnn::ConstTensor updates(updatesInfo, updatesData);
+
+    armnn::ScatterNdDescriptor desc;
+
+    armnn::INetworkPtr network = armnn::INetwork::Create();
+    armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+    armnn::IConnectableLayer* const indicesLayer = network->AddConstantLayer(indices, "Indices");
+    armnn::IConnectableLayer* const updatesLayer = network->AddConstantLayer(updates, "Updates");
+    armnn::IConnectableLayer* const scatterNdLayer = network->AddScatterNdLayer(desc, layerName.c_str());
+    armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+
+    inputLayer->GetOutputSlot(0).Connect(scatterNdLayer->GetInputSlot(0));
+    indicesLayer->GetOutputSlot(0).Connect(scatterNdLayer->GetInputSlot(1));
+    updatesLayer->GetOutputSlot(0).Connect(scatterNdLayer->GetInputSlot(2));
+    scatterNdLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+    inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
+    indicesLayer->GetOutputSlot(0).SetTensorInfo(indicesInfo);
+    updatesLayer->GetOutputSlot(0).SetTensorInfo(updatesInfo);
+    scatterNdLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+    CHECK(deserializedNetwork);
+
+    LayerVerifierBaseWithDescriptor<armnn::ScatterNdDescriptor> verifier(layerName,
+                                                                         {inputInfo, indicesInfo, updatesInfo},
+                                                                         {outputInfo},
+                                                                         desc);
+    deserializedNetwork->ExecuteStrategy(verifier);
+}
+
 }