IVGCVSW-8230 Add ScatterNd to Serializer and Deserializer
* Added parsing functions to the serializer and deserializer
* Added ScatterNd and its Descriptor to the ArmnnSchema.fbs
* Added unit tests for the Serializer and Deserializer
Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: I1ed674dc32d2e2d0d84dca4c7018984ea367ea50
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index ef2ca48..ffdac43 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2019-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "Serializer.hpp"
@@ -97,6 +97,25 @@
}
}
+// Translates an armnn::ScatterNdFunction into the matching FlatBuffer schema
+// enumerator written to the serialized model. Any value without an explicit
+// mapping below is written as ScatterNdFunction_Update.
+serializer::ScatterNdFunction GetFlatBufferScatterNdFunction(armnn::ScatterNdFunction function)
+{
+    using FbFunction = serializer::ScatterNdFunction;
+    switch (function)
+    {
+        case armnn::ScatterNdFunction::Add: return FbFunction::ScatterNdFunction_Add;
+        case armnn::ScatterNdFunction::Sub: return FbFunction::ScatterNdFunction_Sub;
+        case armnn::ScatterNdFunction::Max: return FbFunction::ScatterNdFunction_Max;
+        case armnn::ScatterNdFunction::Min: return FbFunction::ScatterNdFunction_Min;
+        case armnn::ScatterNdFunction::Update:
+        default:
+            // Update is both an explicit mapping and the fallback.
+            return FbFunction::ScatterNdFunction_Update;
+    }
+}
+
uint32_t SerializerStrategy::GetSerializedId(LayerGuid guid)
{
if (m_guidMap.empty())
@@ -1347,6 +1366,32 @@
CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_NormalizationLayer);
}
+// Serializes a ScatterNd layer as its base layer, ScatterNd descriptor and AnyLayer entry.
+void SerializerStrategy::SerializeScatterNdLayer(const armnn::IConnectableLayer* layer,
+                                                 const armnn::ScatterNdDescriptor& descriptor,
+                                                 const char* name)
+{
+    IgnoreUnused(name);
+    // Create FlatBuffer BaseLayer
+    auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ScatterNd);
+    // Create the FlatBuffer ScatterNdDescriptor from the armnn descriptor fields
+    auto flatBufferDesc = serializer::CreateScatterNdDescriptor(
+        m_flatBufferBuilder,
+        GetFlatBufferScatterNdFunction(descriptor.m_Function),
+        descriptor.m_InputEnabled,
+        descriptor.m_Axis,
+        descriptor.m_AxisEnabled);
+
+    // Create the FlatBuffer ScatterNdLayer
+    auto flatBufferLayer = serializer::CreateScatterNdLayer(
+        m_flatBufferBuilder,
+        flatBufferBaseLayer,
+        flatBufferDesc);
+
+    // Add the AnyLayer to the FlatBufferLayers
+    CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ScatterNdLayer);
+}
+
void SerializerStrategy::SerializeShapeLayer(const armnn::IConnectableLayer* layer,
const char* name)
{
@@ -2379,6 +2424,13 @@
SerializeReverseV2Layer(layer, name);
break;
}
+ case armnn::LayerType::ScatterNd:
+ {
+ const armnn::ScatterNdDescriptor& layerDescriptor =
+ static_cast<const armnn::ScatterNdDescriptor&>(descriptor);
+ SerializeScatterNdLayer(layer, layerDescriptor, name);
+ break;
+ }
case armnn::LayerType::Shape:
{
SerializeShapeLayer(layer, name);