IVGCVSW-3526 Add layer norm support for LSTM serialization
* Adds layer normalization support to LSTM serialization/deserialization
* Adds related unit tests
Change-Id: If80b668accc8b0754a93d18ab3a243284cb383d1
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index b59bac60..05df2c9 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -402,7 +402,8 @@
descriptor.m_ClippingThresProj,
descriptor.m_CifgEnabled,
descriptor.m_PeepholeEnabled,
- descriptor.m_ProjectionEnabled);
+ descriptor.m_ProjectionEnabled,
+ descriptor.m_LayerNormEnabled);
// Get mandatory input parameters
auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
@@ -424,6 +425,10 @@
flatbuffers::Offset<serializer::ConstTensor> projectionBias;
flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
if (!descriptor.m_CifgEnabled)
{
@@ -445,6 +450,17 @@
cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
}
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
+ }
+ forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
+ cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
+ outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
+ }
+
auto fbLstmParams = serializer::CreateLstmInputParams(
m_flatBufferBuilder,
inputToForgetWeights,
@@ -463,7 +479,11 @@
projectionWeights,
projectionBias,
cellToForgetWeights,
- cellToOutputWeights);
+ cellToOutputWeights,
+ inputLayerNormWeights,
+ forgetLayerNormWeights,
+ cellLayerNormWeights,
+ outputLayerNormWeights);
auto fbLstmLayer = serializer::CreateLstmLayer(
m_flatBufferBuilder,