IVGCVSW-2694: serialize/deserialize LSTM

* added serialization and deserialization support for the Lstm layer, plus round-trip unit tests
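
For reviewers, a minimal sketch of the round trip the new tests exercise. The
ISerializer/IDeserializer entry points shown are the ones assumed by the
existing SerializerTests helpers (SerializeNetwork/DeserializeNetwork) and are
not part of this change; check their headers before copying.

    #include <armnn/INetwork.hpp>
    #include <armnnSerializer/ISerializer.hpp>
    #include <armnnDeserializer/IDeserializer.hpp>

    #include <cstdint>
    #include <sstream>
    #include <vector>

    armnn::INetworkPtr RoundTrip(const armnn::INetwork& network)
    {
        // Serialize the network (including the Lstm descriptor and weight
        // tensors) into the flatbuffer format defined by ArmnnSchema.fbs.
        auto serializer = armnnSerializer::ISerializer::Create();
        serializer->Serialize(network);

        std::stringstream stream;
        serializer->SaveSerializedToStream(stream);

        // Deserialize; ParseLstm rebuilds the layer from LstmDescriptor
        // and LstmInputParams.
        const std::string data = stream.str();
        std::vector<uint8_t> binary(data.begin(), data.end());

        auto deserializer = armnnDeserializer::IDeserializer::Create();
        return deserializer->CreateNetworkFromBinary(binary);
    }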

Change-Id: Ic59557f03001c496008c4bef92c2e0406e1fbc6c
Signed-off-by: Nina Drozd <nina.drozd@arm.com>
Signed-off-by: Jim Flynn <jim.flynn@arm.com>
diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp
index fa836d0..2b99f28 100644
--- a/src/armnn/layers/LstmLayer.cpp
+++ b/src/armnn/layers/LstmLayer.cpp
@@ -252,110 +252,144 @@
 void LstmLayer::Accept(ILayerVisitor& visitor) const
 {
     LstmInputParams inputParams;
+    ConstTensor inputToInputWeightsTensor;
     if (m_CifgParameters.m_InputToInputWeights != nullptr)
     {
-        ConstTensor inputToInputWeightsTensor(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(),
-                                              m_CifgParameters.m_InputToInputWeights->Map(true));
+        ConstTensor inputToInputWeightsTensorCopy(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(),
+                                                  m_CifgParameters.m_InputToInputWeights->Map(true));
+        inputToInputWeightsTensor = inputToInputWeightsTensorCopy;
         inputParams.m_InputToInputWeights = &inputToInputWeightsTensor;
     }
+    ConstTensor inputToForgetWeightsTensor;
     if (m_BasicParameters.m_InputToForgetWeights != nullptr)
     {
-        ConstTensor inputToForgetWeightsTensor(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(),
-                                               m_BasicParameters.m_InputToForgetWeights->Map(true));
+        ConstTensor inputToForgetWeightsTensorCopy(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(),
+                                                   m_BasicParameters.m_InputToForgetWeights->Map(true));
+        inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy;
         inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor;
     }
+    ConstTensor inputToCellWeightsTensor;
     if (m_BasicParameters.m_InputToCellWeights != nullptr)
     {
-        ConstTensor inputToCellWeightsTensor(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(),
-                                             m_BasicParameters.m_InputToCellWeights->Map(true));
+        ConstTensor inputToCellWeightsTensorCopy(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(),
+                                                 m_BasicParameters.m_InputToCellWeights->Map(true));
+        inputToCellWeightsTensor = inputToCellWeightsTensorCopy;
         inputParams.m_InputToCellWeights = &inputToCellWeightsTensor;
     }
+    ConstTensor inputToOutputWeightsTensor;
     if (m_BasicParameters.m_InputToOutputWeights != nullptr)
     {
-        ConstTensor inputToOutputWeightsTensor(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(),
-                                               m_BasicParameters.m_InputToOutputWeights->Map(true));
+        ConstTensor inputToOutputWeightsTensorCopy(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(),
+                                                   m_BasicParameters.m_InputToOutputWeights->Map(true));
+        inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy;
         inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor;
     }
+    ConstTensor recurrentToInputWeightsTensor;
     if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
     {
-        ConstTensor recurrentToInputWeightsTensor(
+        ConstTensor recurrentToInputWeightsTensorCopy(
                 m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(),
                 m_CifgParameters.m_RecurrentToInputWeights->Map(true));
+        recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy;
         inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
     }
+    ConstTensor recurrentToForgetWeightsTensor;
     if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
     {
-        ConstTensor recurrentToForgetWeightsTensor(
+        ConstTensor recurrentToForgetWeightsTensorCopy(
                 m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
                 m_BasicParameters.m_RecurrentToForgetWeights->Map(true));
+        recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy;
         inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
     }
+    ConstTensor recurrentToCellWeightsTensor;
     if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
     {
-        ConstTensor recurrentToCellWeightsTensor(
+        ConstTensor recurrentToCellWeightsTensorCopy(
                 m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(),
                 m_BasicParameters.m_RecurrentToCellWeights->Map(true));
+        recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy;
         inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
     }
+    ConstTensor recurrentToOutputWeightsTensor;
     if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
     {
-        ConstTensor recurrentToOutputWeightsTensor(
+        ConstTensor recurrentToOutputWeightsTensorCopy(
                 m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
                 m_BasicParameters.m_RecurrentToOutputWeights->Map(true));
+        recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy;
         inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
     }
+    ConstTensor cellToInputWeightsTensor;
     if (m_CifgParameters.m_CellToInputWeights != nullptr)
     {
-        ConstTensor cellToInputWeightsTensor(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(),
-                                             m_CifgParameters.m_CellToInputWeights->Map(true));
+        ConstTensor cellToInputWeightsTensorCopy(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(),
+                                                 m_CifgParameters.m_CellToInputWeights->Map(true));
+        cellToInputWeightsTensor = cellToInputWeightsTensorCopy;
         inputParams.m_CellToInputWeights = &cellToInputWeightsTensor;
     }
+    ConstTensor cellToForgetWeightsTensor;
     if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
     {
-        ConstTensor cellToForgetWeightsTensor(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(),
-                                              m_PeepholeParameters.m_CellToForgetWeights->Map(true));
+        ConstTensor cellToForgetWeightsTensorCopy(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(),
+                                                  m_PeepholeParameters.m_CellToForgetWeights->Map(true));
+        cellToForgetWeightsTensor = cellToForgetWeightsTensorCopy;
         inputParams.m_CellToForgetWeights = &cellToForgetWeightsTensor;
     }
+    ConstTensor cellToOutputWeightsTensor;
     if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
     {
-        ConstTensor cellToOutputWeightsTensor(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(),
-                                              m_PeepholeParameters.m_CellToOutputWeights->Map(true));
+        ConstTensor cellToOutputWeightsTensorCopy(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(),
+                                                  m_PeepholeParameters.m_CellToOutputWeights->Map(true));
+        cellToOutputWeightsTensor = cellToOutputWeightsTensorCopy;
         inputParams.m_CellToOutputWeights = &cellToOutputWeightsTensor;
     }
+    ConstTensor inputGateBiasTensor;
     if (m_CifgParameters.m_InputGateBias != nullptr)
     {
-        ConstTensor inputGateBiasTensor(m_CifgParameters.m_InputGateBias->GetTensorInfo(),
+        ConstTensor inputGateBiasTensorCopy(m_CifgParameters.m_InputGateBias->GetTensorInfo(),
                                         m_CifgParameters.m_InputGateBias->Map(true));
+        inputGateBiasTensor = inputGateBiasTensorCopy;
         inputParams.m_InputGateBias = &inputGateBiasTensor;
     }
+    ConstTensor forgetGateBiasTensor;
     if (m_BasicParameters.m_ForgetGateBias != nullptr)
     {
-        ConstTensor forgetGateBiasTensor(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(),
-                                         m_BasicParameters.m_ForgetGateBias->Map(true));
+        ConstTensor forgetGateBiasTensorCopy(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(),
+                                             m_BasicParameters.m_ForgetGateBias->Map(true));
+        forgetGateBiasTensor = forgetGateBiasTensorCopy;
         inputParams.m_ForgetGateBias = &forgetGateBiasTensor;
     }
+    ConstTensor cellBiasTensor;
     if (m_BasicParameters.m_CellBias != nullptr)
     {
-        ConstTensor cellBiasTensor(m_BasicParameters.m_CellBias->GetTensorInfo(),
-                                   m_BasicParameters.m_CellBias->Map(true));
+        ConstTensor cellBiasTensorCopy(m_BasicParameters.m_CellBias->GetTensorInfo(),
+                                       m_BasicParameters.m_CellBias->Map(true));
+        cellBiasTensor = cellBiasTensorCopy;
         inputParams.m_CellBias = &cellBiasTensor;
     }
+    ConstTensor outputGateBias;
     if (m_BasicParameters.m_OutputGateBias != nullptr)
     {
-        ConstTensor outputGateBias(m_BasicParameters.m_OutputGateBias->GetTensorInfo(),
-                                   m_BasicParameters.m_OutputGateBias->Map(true));
+        ConstTensor outputGateBiasCopy(m_BasicParameters.m_OutputGateBias->GetTensorInfo(),
+                                       m_BasicParameters.m_OutputGateBias->Map(true));
+        outputGateBias = outputGateBiasCopy;
         inputParams.m_OutputGateBias = &outputGateBias;
     }
+    ConstTensor projectionWeightsTensor;
     if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
     {
-        ConstTensor projectionWeightsTensor(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(),
-                                            m_ProjectionParameters.m_ProjectionWeights->Map(true));
+        ConstTensor projectionWeightsTensorCopy(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(),
+                                                m_ProjectionParameters.m_ProjectionWeights->Map(true));
+        projectionWeightsTensor = projectionWeightsTensorCopy;
         inputParams.m_ProjectionWeights = &projectionWeightsTensor;
     }
+    ConstTensor projectionBiasTensor;
     if (m_ProjectionParameters.m_ProjectionBias != nullptr)
     {
-        ConstTensor projectionBiasTensor(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(),
-                                         m_ProjectionParameters.m_ProjectionBias->Map(true));
+        ConstTensor projectionBiasTensorCopy(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(),
+                                             m_ProjectionParameters.m_ProjectionBias->Map(true));
+        projectionBiasTensor = projectionBiasTensorCopy;
         inputParams.m_ProjectionBias = &projectionBiasTensor;
     }
 
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 152a5b4..d64bed7 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -201,6 +201,7 @@
     m_ParserFunctions[Layer_GatherLayer]                 = &Deserializer::ParseGather;
     m_ParserFunctions[Layer_GreaterLayer]                = &Deserializer::ParseGreater;
     m_ParserFunctions[Layer_L2NormalizationLayer]        = &Deserializer::ParseL2Normalization;
+    m_ParserFunctions[Layer_LstmLayer]                   = &Deserializer::ParseLstm;
     m_ParserFunctions[Layer_MaximumLayer]                = &Deserializer::ParseMaximum;
     m_ParserFunctions[Layer_MeanLayer]                   = &Deserializer::ParseMean;
     m_ParserFunctions[Layer_MinimumLayer]                = &Deserializer::ParseMinimum;
@@ -258,6 +259,8 @@
             return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->base();
         case Layer::Layer_L2NormalizationLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_L2NormalizationLayer()->base();
+        case Layer::Layer_LstmLayer:
+            return graphPtr->layers()->Get(layerIndex)->layer_as_LstmLayer()->base();
         case Layer::Layer_MeanLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_MeanLayer()->base();
         case Layer::Layer_MinimumLayer:
@@ -1927,4 +1930,114 @@
     RegisterOutputSlots(graph, layerIndex, layer);
 }
 
+armnn::LstmDescriptor Deserializer::GetLstmDescriptor(Deserializer::LstmDescriptorPtr lstmDescriptor)
+{
+    armnn::LstmDescriptor desc;
+
+    desc.m_ActivationFunc = lstmDescriptor->activationFunc();
+    desc.m_ClippingThresCell = lstmDescriptor->clippingThresCell();
+    desc.m_ClippingThresProj = lstmDescriptor->clippingThresProj();
+    desc.m_CifgEnabled = lstmDescriptor->cifgEnabled();
+    desc.m_PeepholeEnabled = lstmDescriptor->peepholeEnabled();
+    desc.m_ProjectionEnabled = lstmDescriptor->projectionEnabled();
+
+    return desc;
+}
+
+void Deserializer::ParseLstm(GraphPtr graph, unsigned int layerIndex)
+{
+    CHECK_LAYERS(graph, 0, layerIndex);
+
+    auto inputs = GetInputs(graph, layerIndex);
+    CHECK_VALID_SIZE(inputs.size(), 3);
+
+    auto outputs = GetOutputs(graph, layerIndex);
+    CHECK_VALID_SIZE(outputs.size(), 4);
+
+    auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_LstmLayer();
+    auto layerName = GetLayerName(graph, layerIndex);
+    auto flatBufferDescriptor = flatBufferLayer->descriptor();
+    auto flatBufferInputParams = flatBufferLayer->inputParams();
+
+    auto lstmDescriptor = GetLstmDescriptor(flatBufferDescriptor);
+
+    armnn::LstmInputParams lstmInputParams;
+
+    armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights());
+    armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights());
+    armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights());
+    armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights());
+    armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights());
+    armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights());
+    armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias());
+    armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias());
+    armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias());
+
+    lstmInputParams.m_InputToForgetWeights = &inputToForgetWeights;
+    lstmInputParams.m_InputToCellWeights = &inputToCellWeights;
+    lstmInputParams.m_InputToOutputWeights = &inputToOutputWeights;
+    lstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+    lstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights;
+    lstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+    lstmInputParams.m_ForgetGateBias = &forgetGateBias;
+    lstmInputParams.m_CellBias = &cellBias;
+    lstmInputParams.m_OutputGateBias = &outputGateBias;
+
+    armnn::ConstTensor inputToInputWeights;
+    armnn::ConstTensor recurrentToInputWeights;
+    armnn::ConstTensor cellToInputWeights;
+    armnn::ConstTensor inputGateBias;
+    if (!lstmDescriptor.m_CifgEnabled)
+    {
+        inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights());
+        recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights());
+        cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights());
+        inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias());
+
+        lstmInputParams.m_InputToInputWeights = &inputToInputWeights;
+        lstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights;
+        lstmInputParams.m_CellToInputWeights = &cellToInputWeights;
+        lstmInputParams.m_InputGateBias = &inputGateBias;
+    }
+
+    armnn::ConstTensor projectionWeights;
+    armnn::ConstTensor projectionBias;
+    if (lstmDescriptor.m_ProjectionEnabled)
+    {
+        projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights());
+        projectionBias = ToConstTensor(flatBufferInputParams->projectionBias());
+
+        lstmInputParams.m_ProjectionWeights = &projectionWeights;
+        lstmInputParams.m_ProjectionBias = &projectionBias;
+    }
+
+    armnn::ConstTensor cellToForgetWeights;
+    armnn::ConstTensor cellToOutputWeights;
+    if (lstmDescriptor.m_PeepholeEnabled)
+    {
+        cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights());
+        cellToOutputWeights = ToConstTensor(flatBufferInputParams->cellToOutputWeights());
+
+        lstmInputParams.m_CellToForgetWeights = &cellToForgetWeights;
+        lstmInputParams.m_CellToOutputWeights = &cellToOutputWeights;
+    }
+
+    IConnectableLayer* layer = m_Network->AddLstmLayer(lstmDescriptor, lstmInputParams, layerName.c_str());
+
+    armnn::TensorInfo outputTensorInfo1 = ToTensorInfo(outputs[0]);
+    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo1);
+
+    armnn::TensorInfo outputTensorInfo2 = ToTensorInfo(outputs[1]);
+    layer->GetOutputSlot(1).SetTensorInfo(outputTensorInfo2);
+
+    armnn::TensorInfo outputTensorInfo3 = ToTensorInfo(outputs[2]);
+    layer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo3);
+
+    armnn::TensorInfo outputTensorInfo4 = ToTensorInfo(outputs[3]);
+    layer->GetOutputSlot(3).SetTensorInfo(outputTensorInfo4);
+
+    RegisterInputSlots(graph, layerIndex, layer);
+    RegisterOutputSlots(graph, layerIndex, layer);
+}
+
 } // namespace armnnDeserializer
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index effc7ae..6454643 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -22,6 +22,8 @@
     using TensorRawPtr = const armnnSerializer::TensorInfo *;
     using PoolingDescriptor = const armnnSerializer::Pooling2dDescriptor *;
     using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *;
+    using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *;
+    using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *;
     using TensorRawPtrVector = std::vector<TensorRawPtr>;
     using LayerRawPtr = const armnnSerializer::LayerBase *;
     using LayerBaseRawPtr = const armnnSerializer::LayerBase *;
@@ -58,6 +60,9 @@
                                                            unsigned int layerIndex);
     static armnn::NormalizationDescriptor GetNormalizationDescriptor(
         NormalizationDescriptorPtr normalizationDescriptor, unsigned int layerIndex);
+    static armnn::LstmDescriptor GetLstmDescriptor(LstmDescriptorPtr lstmDescriptor);
+    static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor,
+                                                     LstmInputParamsPtr lstmInputParams);
     static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
                                                   const std::vector<uint32_t> & targetDimsIn);
 
@@ -94,6 +99,7 @@
     void ParseMerger(GraphPtr graph, unsigned int layerIndex);
     void ParseMultiplication(GraphPtr graph, unsigned int layerIndex);
     void ParseNormalization(GraphPtr graph, unsigned int layerIndex);
+    void ParseLstm(GraphPtr graph, unsigned int layerIndex);
     void ParsePad(GraphPtr graph, unsigned int layerIndex);
     void ParsePermute(GraphPtr graph, unsigned int layerIndex);
     void ParsePooling2d(GraphPtr graph, unsigned int layerIndex);
diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md
index 48b8c88..d53252e 100644
--- a/src/armnnDeserializer/DeserializerSupport.md
+++ b/src/armnnDeserializer/DeserializerSupport.md
@@ -21,6 +21,7 @@
 * Gather
 * Greater
 * L2Normalization
+* Lstm
 * Maximum
 * Mean
 * Merger
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index a11eead..2cceaae 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -115,7 +115,8 @@
     Merger = 30,
     L2Normalization = 31,
     Splitter = 32,
-    DetectionPostProcess = 33
+    DetectionPostProcess = 33,
+    Lstm = 34
 }
 
 // Base layer table to be used as part of other layers
@@ -475,6 +476,44 @@
     scaleH:float;
 }
 
+table LstmInputParams {
+    inputToForgetWeights:ConstTensor;
+    inputToCellWeights:ConstTensor;
+    inputToOutputWeights:ConstTensor;
+    recurrentToForgetWeights:ConstTensor;
+    recurrentToCellWeights:ConstTensor;
+    recurrentToOutputWeights:ConstTensor;
+    forgetGateBias:ConstTensor;
+    cellBias:ConstTensor;
+    outputGateBias:ConstTensor;
+
+    inputToInputWeights:ConstTensor;
+    recurrentToInputWeights:ConstTensor;
+    cellToInputWeights:ConstTensor;
+    inputGateBias:ConstTensor;
+
+    projectionWeights:ConstTensor;
+    projectionBias:ConstTensor;
+
+    cellToForgetWeights:ConstTensor;
+    cellToOutputWeights:ConstTensor;
+}
+
+table LstmDescriptor {
+    activationFunc:uint;
+    clippingThresCell:float;
+    clippingThresProj:float;
+    cifgEnabled:bool = true;
+    peepholeEnabled:bool = false;
+    projectionEnabled:bool = false;
+}
+
+table LstmLayer {
+    base:LayerBase;
+    descriptor:LstmDescriptor;
+    inputParams:LstmInputParams;
+}
+
 union Layer {
     ActivationLayer,
     AdditionLayer,
@@ -509,7 +548,8 @@
     MergerLayer,
     L2NormalizationLayer,
     SplitterLayer,
-    DetectionPostProcessLayer
+    DetectionPostProcessLayer,
+    LstmLayer
 }
 
 table AnyLayer {
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index a27cbc0..2fd8402 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -375,6 +375,90 @@
     CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_L2NormalizationLayer);
 }
 
+void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, const armnn::LstmDescriptor& descriptor,
+                                       const armnn::LstmInputParams& params, const char* name)
+{
+    auto fbLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Lstm);
+
+    auto fbLstmDescriptor = serializer::CreateLstmDescriptor(
+        m_flatBufferBuilder,
+        descriptor.m_ActivationFunc,
+        descriptor.m_ClippingThresCell,
+        descriptor.m_ClippingThresProj,
+        descriptor.m_CifgEnabled,
+        descriptor.m_PeepholeEnabled,
+        descriptor.m_ProjectionEnabled);
+
+    // Get mandatory input parameters
+    auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
+    auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
+    auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
+    auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
+    auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
+    auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
+    auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
+    auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
+    auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
+
+    // Define optional parameters; these are only set when the corresponding option is enabled in the Lstm descriptor
+    flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
+    flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
+    flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
+    flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
+    flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
+    flatbuffers::Offset<serializer::ConstTensor> projectionBias;
+    flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
+    flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
+
+    if (!descriptor.m_CifgEnabled)
+    {
+        inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
+        recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
+        cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights);
+        inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
+    }
+
+    if (descriptor.m_ProjectionEnabled)
+    {
+        projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
+        projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
+    }
+
+    if (descriptor.m_PeepholeEnabled)
+    {
+        cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
+        cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
+    }
+
+    auto fbLstmParams = serializer::CreateLstmInputParams(
+        m_flatBufferBuilder,
+        inputToForgetWeights,
+        inputToCellWeights,
+        inputToOutputWeights,
+        recurrentToForgetWeights,
+        recurrentToCellWeights,
+        recurrentToOutputWeights,
+        forgetGateBias,
+        cellBias,
+        outputGateBias,
+        inputToInputWeights,
+        recurrentToInputWeights,
+        cellToInputWeights,
+        inputGateBias,
+        projectionWeights,
+        projectionBias,
+        cellToForgetWeights,
+        cellToOutputWeights);
+
+    auto fbLstmLayer = serializer::CreateLstmLayer(
+        m_flatBufferBuilder,
+        fbLstmBaseLayer,
+        fbLstmDescriptor,
+        fbLstmParams);
+
+    CreateAnyLayer(fbLstmLayer.o, serializer::Layer::Layer_LstmLayer);
+}
+
 void SerializerVisitor::VisitMaximumLayer(const armnn::IConnectableLayer* layer, const char* name)
 {
     auto fbMaximumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Maximum);
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index 71066d2..4573bfd 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -111,6 +111,11 @@
                                    const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor,
                                    const char* name = nullptr) override;
 
+    void VisitLstmLayer(const armnn::IConnectableLayer* layer,
+                        const armnn::LstmDescriptor& descriptor,
+                        const armnn::LstmInputParams& params,
+                        const char* name = nullptr) override;
+
     void VisitMeanLayer(const armnn::IConnectableLayer* layer,
                         const armnn::MeanDescriptor& descriptor,
                         const char* name) override;
diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md
index 4e127b3..7686d5c 100644
--- a/src/armnnSerializer/SerializerSupport.md
+++ b/src/armnnSerializer/SerializerSupport.md
@@ -21,6 +21,7 @@
 * Gather
 * Greater
 * L2Normalization
+* Lstm
 * Maximum
 * Mean
 * Merger
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index f40c02d..e3ce6d2 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -2047,4 +2047,379 @@
     deserializedNetwork->Accept(verifier);
 }
 
+class VerifyLstmLayer : public LayerVerifierBase
+{
+public:
+    VerifyLstmLayer(const std::string& layerName,
+                    const std::vector<armnn::TensorInfo>& inputInfos,
+                    const std::vector<armnn::TensorInfo>& outputInfos,
+                    const armnn::LstmDescriptor& descriptor,
+                    const armnn::LstmInputParams& inputParams) :
+         LayerVerifierBase(layerName, inputInfos, outputInfos), m_Descriptor(descriptor), m_InputParams(inputParams)
+    {
+    }
+    void VisitLstmLayer(const armnn::IConnectableLayer* layer,
+                        const armnn::LstmDescriptor& descriptor,
+                        const armnn::LstmInputParams& params,
+                        const char* name)
+    {
+        VerifyNameAndConnections(layer, name);
+        VerifyDescriptor(descriptor);
+        VerifyInputParameters(params);
+    }
+protected:
+    void VerifyDescriptor(const armnn::LstmDescriptor& descriptor)
+    {
+        BOOST_TEST(m_Descriptor.m_ActivationFunc == descriptor.m_ActivationFunc);
+        BOOST_TEST(m_Descriptor.m_ClippingThresCell == descriptor.m_ClippingThresCell);
+        BOOST_TEST(m_Descriptor.m_ClippingThresProj == descriptor.m_ClippingThresProj);
+        BOOST_TEST(m_Descriptor.m_CifgEnabled == descriptor.m_CifgEnabled);
+        BOOST_TEST(m_Descriptor.m_PeepholeEnabled == descriptor.m_PeepholeEnabled);
+        BOOST_TEST(m_Descriptor.m_ProjectionEnabled == descriptor.m_ProjectionEnabled);
+    }
+    void VerifyInputParameters(const armnn::LstmInputParams& params)
+    {
+        VerifyConstTensors(
+            "m_InputToInputWeights", m_InputParams.m_InputToInputWeights, params.m_InputToInputWeights);
+        VerifyConstTensors(
+            "m_InputToForgetWeights", m_InputParams.m_InputToForgetWeights, params.m_InputToForgetWeights);
+        VerifyConstTensors(
+            "m_InputToCellWeights", m_InputParams.m_InputToCellWeights, params.m_InputToCellWeights);
+        VerifyConstTensors(
+            "m_InputToOutputWeights", m_InputParams.m_InputToOutputWeights, params.m_InputToOutputWeights);
+        VerifyConstTensors(
+            "m_RecurrentToInputWeights", m_InputParams.m_RecurrentToInputWeights, params.m_RecurrentToInputWeights);
+        VerifyConstTensors(
+            "m_RecurrentToForgetWeights", m_InputParams.m_RecurrentToForgetWeights, params.m_RecurrentToForgetWeights);
+        VerifyConstTensors(
+            "m_RecurrentToCellWeights", m_InputParams.m_RecurrentToCellWeights, params.m_RecurrentToCellWeights);
+        VerifyConstTensors(
+            "m_RecurrentToOutputWeights", m_InputParams.m_RecurrentToOutputWeights, params.m_RecurrentToOutputWeights);
+        VerifyConstTensors(
+            "m_CellToInputWeights", m_InputParams.m_CellToInputWeights, params.m_CellToInputWeights);
+        VerifyConstTensors(
+            "m_CellToForgetWeights", m_InputParams.m_CellToForgetWeights, params.m_CellToForgetWeights);
+        VerifyConstTensors(
+            "m_CellToOutputWeights", m_InputParams.m_CellToOutputWeights, params.m_CellToOutputWeights);
+        VerifyConstTensors(
+            "m_InputGateBias", m_InputParams.m_InputGateBias, params.m_InputGateBias);
+        VerifyConstTensors(
+            "m_ForgetGateBias", m_InputParams.m_ForgetGateBias, params.m_ForgetGateBias);
+        VerifyConstTensors(
+            "m_CellBias", m_InputParams.m_CellBias, params.m_CellBias);
+        VerifyConstTensors(
+            "m_OutputGateBias", m_InputParams.m_OutputGateBias, params.m_OutputGateBias);
+        VerifyConstTensors(
+            "m_ProjectionWeights", m_InputParams.m_ProjectionWeights, params.m_ProjectionWeights);
+        VerifyConstTensors(
+            "m_ProjectionBias", m_InputParams.m_ProjectionBias, params.m_ProjectionBias);
+    }
+    void VerifyConstTensors(const std::string& tensorName,
+                            const armnn::ConstTensor* expectedPtr,
+                            const armnn::ConstTensor* actualPtr)
+    {
+        if (expectedPtr == nullptr)
+        {
+            BOOST_CHECK_MESSAGE(actualPtr == nullptr, tensorName + " should not exist");
+        }
+        else
+        {
+            BOOST_CHECK_MESSAGE(actualPtr != nullptr, tensorName + " should have been set");
+            if (actualPtr != nullptr)
+            {
+                const armnn::TensorInfo& expectedInfo = expectedPtr->GetInfo();
+                const armnn::TensorInfo& actualInfo = actualPtr->GetInfo();
+
+                BOOST_CHECK_MESSAGE(expectedInfo.GetShape() == actualInfo.GetShape(),
+                                    tensorName + " shapes don't match");
+                BOOST_CHECK_MESSAGE(
+                    GetDataTypeName(expectedInfo.GetDataType()) == GetDataTypeName(actualInfo.GetDataType()),
+                    tensorName + " data types don't match");
+
+                BOOST_CHECK_MESSAGE(expectedPtr->GetNumBytes() == actualPtr->GetNumBytes(),
+                                    tensorName + " (GetNumBytes) data sizes do not match");
+                if (expectedPtr->GetNumBytes() == actualPtr->GetNumBytes())
+                {
+                    // check the data is identical
+                    const char* expectedData = static_cast<const char*>(expectedPtr->GetMemoryArea());
+                    const char* actualData = static_cast<const char*>(actualPtr->GetMemoryArea());
+                    bool same = true;
+                    for (unsigned int i = 0; i < expectedPtr->GetNumBytes(); ++i)
+                    {
+                        same = expectedData[i] == actualData[i];
+                        if (!same)
+                        {
+                            break;
+                        }
+                    }
+                    BOOST_CHECK_MESSAGE(same, tensorName + " data does not match");
+                }
+            }
+        }
+    }
+private:
+    armnn::LstmDescriptor m_Descriptor;
+    armnn::LstmInputParams m_InputParams;
+};
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmCifgPeepholeNoProjection)
+{
+    armnn::LstmDescriptor descriptor;
+    descriptor.m_ActivationFunc = 4;
+    descriptor.m_ClippingThresProj = 0.0f;
+    descriptor.m_ClippingThresCell = 0.0f;
+    descriptor.m_CifgEnabled = true; // with CIFG enabled the optional input gate parameters are not set
+    descriptor.m_ProjectionEnabled = false;
+    descriptor.m_PeepholeEnabled = true;
+
+    const uint32_t batchSize = 1;
+    const uint32_t inputSize = 2;
+    const uint32_t numUnits = 4;
+    const uint32_t outputSize = numUnits;
+
+    armnn::TensorInfo inputWeightsInfo1({numUnits, inputSize}, armnn::DataType::Float32);
+    std::vector<float> inputToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+    armnn::ConstTensor inputToForgetWeights(inputWeightsInfo1, inputToForgetWeightsData);
+
+    std::vector<float> inputToCellWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+    armnn::ConstTensor inputToCellWeights(inputWeightsInfo1, inputToCellWeightsData);
+
+    std::vector<float> inputToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+    armnn::ConstTensor inputToOutputWeights(inputWeightsInfo1, inputToOutputWeightsData);
+
+    armnn::TensorInfo inputWeightsInfo2({numUnits, outputSize}, armnn::DataType::Float32);
+    std::vector<float> recurrentToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+    armnn::ConstTensor recurrentToForgetWeights(inputWeightsInfo2, recurrentToForgetWeightsData);
+
+    std::vector<float> recurrentToCellWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+    armnn::ConstTensor recurrentToCellWeights(inputWeightsInfo2, recurrentToCellWeightsData);
+
+    std::vector<float> recurrentToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+    armnn::ConstTensor recurrentToOutputWeights(inputWeightsInfo2, recurrentToOutputWeightsData);
+
+    armnn::TensorInfo inputWeightsInfo3({numUnits}, armnn::DataType::Float32);
+    std::vector<float> cellToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo3.GetNumElements());
+    armnn::ConstTensor cellToForgetWeights(inputWeightsInfo3, cellToForgetWeightsData);
+
+    std::vector<float> cellToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo3.GetNumElements());
+    armnn::ConstTensor cellToOutputWeights(inputWeightsInfo3, cellToOutputWeightsData);
+
+    std::vector<float> forgetGateBiasData(numUnits, 1.0f);
+    armnn::ConstTensor forgetGateBias(inputWeightsInfo3, forgetGateBiasData);
+
+    std::vector<float> cellBiasData(numUnits, 0.0f);
+    armnn::ConstTensor cellBias(inputWeightsInfo3, cellBiasData);
+
+    std::vector<float> outputGateBiasData(numUnits, 0.0f);
+    armnn::ConstTensor outputGateBias(inputWeightsInfo3, outputGateBiasData);
+
+    armnn::LstmInputParams params;
+    params.m_InputToForgetWeights = &inputToForgetWeights;
+    params.m_InputToCellWeights = &inputToCellWeights;
+    params.m_InputToOutputWeights = &inputToOutputWeights;
+    params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+    params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+    params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+    params.m_ForgetGateBias = &forgetGateBias;
+    params.m_CellBias = &cellBias;
+    params.m_OutputGateBias = &outputGateBias;
+    params.m_CellToForgetWeights = &cellToForgetWeights;
+    params.m_CellToOutputWeights = &cellToOutputWeights;
+
+    armnn::INetworkPtr network = armnn::INetwork::Create();
+    armnn::IConnectableLayer* const inputLayer   = network->AddInputLayer(0);
+    armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1);
+    armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2);
+    const std::string layerName("lstm");
+    armnn::IConnectableLayer* const lstmLayer = network->AddLstmLayer(descriptor, params, layerName.c_str());
+    armnn::IConnectableLayer* const scratchBuffer  = network->AddOutputLayer(0);
+    armnn::IConnectableLayer* const outputStateOut  = network->AddOutputLayer(1);
+    armnn::IConnectableLayer* const cellStateOut  = network->AddOutputLayer(2);
+    armnn::IConnectableLayer* const outputLayer  = network->AddOutputLayer(3);
+
+    // connect up
+    armnn::TensorInfo inputTensorInfo({ batchSize, inputSize }, armnn::DataType::Float32);
+    armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
+    armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+    armnn::TensorInfo lstmTensorInfoScratchBuff({ batchSize, numUnits * 3 }, armnn::DataType::Float32);
+
+    inputLayer->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(0));
+    inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+
+    outputStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(1));
+    outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo);
+
+    cellStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(2));
+    cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
+
+    lstmLayer->GetOutputSlot(0).Connect(scratchBuffer->GetInputSlot(0));
+    lstmLayer->GetOutputSlot(0).SetTensorInfo(lstmTensorInfoScratchBuff);
+
+    lstmLayer->GetOutputSlot(1).Connect(outputStateOut->GetInputSlot(0));
+    lstmLayer->GetOutputSlot(1).SetTensorInfo(outputStateTensorInfo);
+
+    lstmLayer->GetOutputSlot(2).Connect(cellStateOut->GetInputSlot(0));
+    lstmLayer->GetOutputSlot(2).SetTensorInfo(cellStateTensorInfo);
+
+    lstmLayer->GetOutputSlot(3).Connect(outputLayer->GetInputSlot(0));
+    lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo);
+
+    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+    BOOST_CHECK(deserializedNetwork);
+
+    VerifyLstmLayer checker(
+        layerName,
+        {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
+        {lstmTensorInfoScratchBuff, outputStateTensorInfo, cellStateTensorInfo, outputStateTensorInfo},
+        descriptor,
+        params);
+    deserializedNetwork->Accept(checker);
+}
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeAndProjection)
+{
+    armnn::LstmDescriptor descriptor;
+    descriptor.m_ActivationFunc = 4;
+    descriptor.m_ClippingThresProj = 0.0f;
+    descriptor.m_ClippingThresCell = 0.0f;
+    descriptor.m_CifgEnabled = false; // with CIFG disabled the input gate parameters below must be set
+    descriptor.m_ProjectionEnabled = true;
+    descriptor.m_PeepholeEnabled = true;
+
+    const uint32_t batchSize = 2;
+    const uint32_t inputSize = 5;
+    const uint32_t numUnits = 20;
+    const uint32_t outputSize = 16;
+
+    armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, armnn::DataType::Float32);
+    std::vector<float> inputToInputWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+    armnn::ConstTensor inputToInputWeights(tensorInfo20x5, inputToInputWeightsData);
+
+    std::vector<float> inputToForgetWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+    armnn::ConstTensor inputToForgetWeights(tensorInfo20x5, inputToForgetWeightsData);
+
+    std::vector<float> inputToCellWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+    armnn::ConstTensor inputToCellWeights(tensorInfo20x5, inputToCellWeightsData);
+
+    std::vector<float> inputToOutputWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+    armnn::ConstTensor inputToOutputWeights(tensorInfo20x5, inputToOutputWeightsData);
+
+    armnn::TensorInfo tensorInfo20({numUnits}, armnn::DataType::Float32);
+    std::vector<float> inputGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+    armnn::ConstTensor inputGateBias(tensorInfo20, inputGateBiasData);
+
+    std::vector<float> forgetGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+    armnn::ConstTensor forgetGateBias(tensorInfo20, forgetGateBiasData);
+
+    std::vector<float> cellBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+    armnn::ConstTensor cellBias(tensorInfo20, cellBiasData);
+
+    std::vector<float> outputGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+    armnn::ConstTensor outputGateBias(tensorInfo20, outputGateBiasData);
+
+    armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, armnn::DataType::Float32);
+    std::vector<float> recurrentToInputWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+    armnn::ConstTensor recurrentToInputWeights(tensorInfo20x16, recurrentToInputWeightsData);
+
+    std::vector<float> recurrentToForgetWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+    armnn::ConstTensor recurrentToForgetWeights(tensorInfo20x16, recurrentToForgetWeightsData);
+
+    std::vector<float> recurrentToCellWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+    armnn::ConstTensor recurrentToCellWeights(tensorInfo20x16, recurrentToCellWeightsData);
+
+    std::vector<float> recurrentToOutputWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+    armnn::ConstTensor recurrentToOutputWeights(tensorInfo20x16, recurrentToOutputWeightsData);
+
+    std::vector<float> cellToInputWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+    armnn::ConstTensor cellToInputWeights(tensorInfo20, cellToInputWeightsData);
+
+    std::vector<float> cellToForgetWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+    armnn::ConstTensor cellToForgetWeights(tensorInfo20, cellToForgetWeightsData);
+
+    std::vector<float> cellToOutputWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+    armnn::ConstTensor cellToOutputWeights(tensorInfo20,  cellToOutputWeightsData);
+
+    armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, armnn::DataType::Float32);
+    std::vector<float> projectionWeightsData = GenerateRandomData<float>(tensorInfo16x20.GetNumElements());
+    armnn::ConstTensor projectionWeights(tensorInfo16x20, projectionWeightsData);
+
+    armnn::TensorInfo tensorInfo16({outputSize}, armnn::DataType::Float32);
+    std::vector<float> projectionBiasData(outputSize, 0.f);
+    armnn::ConstTensor projectionBias(tensorInfo16, projectionBiasData);
+
+    armnn::LstmInputParams params;
+    params.m_InputToForgetWeights = &inputToForgetWeights;
+    params.m_InputToCellWeights = &inputToCellWeights;
+    params.m_InputToOutputWeights = &inputToOutputWeights;
+    params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+    params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+    params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+    params.m_ForgetGateBias = &forgetGateBias;
+    params.m_CellBias = &cellBias;
+    params.m_OutputGateBias = &outputGateBias;
+
+    // additional params because: descriptor.m_CifgEnabled = false
+    params.m_InputToInputWeights = &inputToInputWeights;
+    params.m_RecurrentToInputWeights = &recurrentToInputWeights;
+    params.m_CellToInputWeights = &cellToInputWeights;
+    params.m_InputGateBias = &inputGateBias;
+
+    // additional params because: descriptor.m_ProjectionEnabled = true
+    params.m_ProjectionWeights = &projectionWeights;
+    params.m_ProjectionBias = &projectionBias;
+
+    // additional params because: descriptor.m_PeepholeEnabled = true
+    params.m_CellToForgetWeights = &cellToForgetWeights;
+    params.m_CellToOutputWeights = &cellToOutputWeights;
+
+    armnn::INetworkPtr network = armnn::INetwork::Create();
+    armnn::IConnectableLayer* const inputLayer   = network->AddInputLayer(0);
+    armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1);
+    armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2);
+    const std::string layerName("lstm");
+    armnn::IConnectableLayer* const lstmLayer = network->AddLstmLayer(descriptor, params, layerName.c_str());
+    armnn::IConnectableLayer* const scratchBuffer  = network->AddOutputLayer(0);
+    armnn::IConnectableLayer* const outputStateOut  = network->AddOutputLayer(1);
+    armnn::IConnectableLayer* const cellStateOut  = network->AddOutputLayer(2);
+    armnn::IConnectableLayer* const outputLayer  = network->AddOutputLayer(3);
+
+    // connect up
+    armnn::TensorInfo inputTensorInfo({ batchSize, inputSize }, armnn::DataType::Float32);
+    armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
+    armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+    armnn::TensorInfo lstmTensorInfoScratchBuff({ batchSize, numUnits * 4 }, armnn::DataType::Float32);
+
+    inputLayer->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(0));
+    inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+
+    outputStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(1));
+    outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo);
+
+    cellStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(2));
+    cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
+
+    lstmLayer->GetOutputSlot(0).Connect(scratchBuffer->GetInputSlot(0));
+    lstmLayer->GetOutputSlot(0).SetTensorInfo(lstmTensorInfoScratchBuff);
+
+    lstmLayer->GetOutputSlot(1).Connect(outputStateOut->GetInputSlot(0));
+    lstmLayer->GetOutputSlot(1).SetTensorInfo(outputStateTensorInfo);
+
+    lstmLayer->GetOutputSlot(2).Connect(cellStateOut->GetInputSlot(0));
+    lstmLayer->GetOutputSlot(2).SetTensorInfo(cellStateTensorInfo);
+
+    lstmLayer->GetOutputSlot(3).Connect(outputLayer->GetInputSlot(0));
+    lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo);
+
+    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+    BOOST_CHECK(deserializedNetwork);
+
+    VerifyLstmLayer checker(
+        layerName,
+        {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
+        {lstmTensorInfoScratchBuff, outputStateTensorInfo, cellStateTensorInfo, outputStateTensorInfo},
+        descriptor,
+        params);
+    deserializedNetwork->Accept(checker);
+}
+
 BOOST_AUTO_TEST_SUITE_END()