IVGCVSW-4777 Add QLstm serialization support
* Adds serialization/deserialization for QLstm.
* 3 unit tests: basic, layer norm and advanced.
Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: I97d825e06b0d4a1257713cdd71ff06afa10d4380
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index ff79f6c..6e5ee3f 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -155,7 +155,8 @@
Comparison = 52,
StandIn = 53,
ElementwiseUnary = 54,
- Transpose = 55
+ Transpose = 55,
+ QLstm = 56
}
// Base layer table to be used as part of other layers
@@ -666,6 +667,81 @@
outputLayerNormWeights:ConstTensor;
}
+table LstmDescriptor { // Settings for LstmLayer.descriptor; declaration order fixes field ids, so append new fields at the end
+ activationFunc:uint; // NOTE(review): presumably an activation-function code stored as uint; confirm against the serializer
+ clippingThresCell:float; // no explicit default, so reads as 0.0 when absent
+ clippingThresProj:float; // no explicit default, so reads as 0.0 when absent
+ cifgEnabled:bool = true; // reads as true when absent from the buffer
+ peepholeEnabled:bool = false;
+ projectionEnabled:bool = false;
+ layerNormEnabled:bool = false;
+}
+
+table LstmLayer { // Serialized LSTM layer; declaration order fixes field ids
+ base:LayerBase; // shared base-layer fields (see LayerBase)
+ descriptor:LstmDescriptor; // layer configuration
+ inputParams:LstmInputParams; // parameter tensors, see LstmInputParams (presumably the weights/biases; confirm)
+}
+
+table QLstmInputParams { // Tensors referenced by QLstmLayer.inputParams; declaration order fixes field ids
+ // Mandatory by convention (FlatBuffers table fields are optional on the wire; none here is marked (required))
+ inputToForgetWeights:ConstTensor;
+ inputToCellWeights:ConstTensor;
+ inputToOutputWeights:ConstTensor;
+
+ recurrentToForgetWeights:ConstTensor;
+ recurrentToCellWeights:ConstTensor;
+ recurrentToOutputWeights:ConstTensor;
+
+ forgetGateBias:ConstTensor;
+ cellBias:ConstTensor;
+ outputGateBias:ConstTensor;
+
+ // CIFG: NOTE(review): presumably populated only when QLstmDescriptor.cifgEnabled is false; confirm in the serializer
+ inputToInputWeights:ConstTensor;
+ recurrentToInputWeights:ConstTensor;
+ inputGateBias:ConstTensor;
+
+ // Projection: presumably populated only when QLstmDescriptor.projectionEnabled is true; confirm in the serializer
+ projectionWeights:ConstTensor;
+ projectionBias:ConstTensor;
+
+ // Peephole: presumably populated only when QLstmDescriptor.peepholeEnabled is true (cellToInputWeights likely also needs CIFG off); confirm
+ cellToInputWeights:ConstTensor;
+ cellToForgetWeights:ConstTensor;
+ cellToOutputWeights:ConstTensor;
+
+ // Layer norm: presumably populated only when QLstmDescriptor.layerNormEnabled is true (inputLayerNormWeights likely also needs CIFG off); confirm
+ inputLayerNormWeights:ConstTensor;
+ forgetLayerNormWeights:ConstTensor;
+ cellLayerNormWeights:ConstTensor;
+ outputLayerNormWeights:ConstTensor;
+}
+
+table QLstmDescriptor { // Settings for QLstmLayer.descriptor; declaration order fixes field ids, so append new fields at the end
+ cifgEnabled:bool = true; // reads as true when absent from the buffer
+ peepholeEnabled:bool = false;
+ projectionEnabled:bool = false;
+ layerNormEnabled:bool = false;
+
+ cellClip:float; // no explicit default, so reads as 0.0 when absent
+ projectionClip:float; // no explicit default, so reads as 0.0 when absent
+
+ inputIntermediateScale:float; // NOTE(review): presumably quantization scales of the per-gate intermediate results; confirm
+ forgetIntermediateScale:float;
+ cellIntermediateScale:float;
+ outputIntermediateScale:float;
+
+ hiddenStateZeroPoint:int; // FlatBuffers int is a 32-bit signed integer; reads as 0 when absent
+ hiddenStateScale:float; // reads as 0.0 when absent
+}
+
+table QLstmLayer { // Serialized QLstm layer (QLstm = 56 in the layer-type enum; QLstmLayer in the layer union)
+ base:LayerBase; // shared base-layer fields (see LayerBase)
+ descriptor:QLstmDescriptor; // flags, clip values and scale/zero-point settings
+ inputParams:QLstmInputParams; // weight and bias tensors for the gates
+}
+
table QuantizedLstmInputParams {
inputToInputWeights:ConstTensor;
inputToForgetWeights:ConstTensor;
@@ -683,22 +759,6 @@
outputGateBias:ConstTensor;
}
-table LstmDescriptor {
- activationFunc:uint;
- clippingThresCell:float;
- clippingThresProj:float;
- cifgEnabled:bool = true;
- peepholeEnabled:bool = false;
- projectionEnabled:bool = false;
- layerNormEnabled:bool = false;
-}
-
-table LstmLayer {
- base:LayerBase;
- descriptor:LstmDescriptor;
- inputParams:LstmInputParams;
-}
-
table QuantizedLstmLayer {
base:LayerBase;
inputParams:QuantizedLstmInputParams;
@@ -836,7 +896,8 @@
ComparisonLayer,
StandInLayer,
ElementwiseUnaryLayer,
- TransposeLayer
+ TransposeLayer,
+ QLstmLayer
}
table AnyLayer {