IVGCVSW-2694: serialize/deserialize LSTM
* Added serialize/deserialize methods for the LSTM layer, together with tests
Change-Id: Ic59557f03001c496008c4bef92c2e0406e1fbc6c
Signed-off-by: Nina Drozd <nina.drozd@arm.com>
Signed-off-by: Jim Flynn <jim.flynn@arm.com>
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index a11eead..2cceaae 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -115,7 +115,8 @@
Merger = 30,
L2Normalization = 31,
Splitter = 32,
- DetectionPostProcess = 33
+ DetectionPostProcess = 33,
+ Lstm = 34
}
// Base layer table to be used as part of other layers
@@ -475,6 +476,44 @@
scaleH:float;
}
+table LstmInputParams { // Constant tensors (weights and biases) of an LSTM layer. No explicit field ids are used, so only append new fields at the end.
+ inputToForgetWeights:ConstTensor;
+ inputToCellWeights:ConstTensor;
+ inputToOutputWeights:ConstTensor;
+ recurrentToForgetWeights:ConstTensor;
+ recurrentToCellWeights:ConstTensor;
+ recurrentToOutputWeights:ConstTensor;
+ forgetGateBias:ConstTensor;
+ cellBias:ConstTensor;
+ outputGateBias:ConstTensor;
+ // NOTE(review): the next four are presumably only populated when cifgEnabled is false -- confirm against the serializer.
+ inputToInputWeights:ConstTensor;
+ recurrentToInputWeights:ConstTensor;
+ cellToInputWeights:ConstTensor;
+ inputGateBias:ConstTensor;
+ // NOTE(review): presumably only populated when projectionEnabled is true -- confirm against the serializer.
+ projectionWeights:ConstTensor;
+ projectionBias:ConstTensor;
+ // NOTE(review): presumably only populated when peepholeEnabled is true -- confirm against the serializer.
+ cellToForgetWeights:ConstTensor;
+ cellToOutputWeights:ConstTensor;
+}
+
+table LstmDescriptor { // Scalar configuration of an LSTM layer; a field absent from the buffer reads back as its declared default.
+ activationFunc:uint; // NOTE(review): the meaning of the numeric codes is not defined in this schema -- confirm against the runtime descriptor.
+ clippingThresCell:float;
+ clippingThresProj:float;
+ cifgEnabled:bool = true; // Defaults to true, unlike the two flags below.
+ peepholeEnabled:bool = false;
+ projectionEnabled:bool = false;
+}
+
+table LstmLayer { // Serialized form of one LSTM layer; it is also added as a member of the Layer union.
+ base:LayerBase;
+ descriptor:LstmDescriptor; // Scalar settings and feature flags.
+ inputParams:LstmInputParams; // Constant weight and bias tensors.
+}
+
union Layer {
ActivationLayer,
AdditionLayer,
@@ -509,7 +548,8 @@
MergerLayer,
L2NormalizationLayer,
SplitterLayer,
- DetectionPostProcessLayer
+ DetectionPostProcessLayer,
+ LstmLayer
}
table AnyLayer {