IVGCVSW-3396 Support joined LSTM parameters

!armnn:1470

Signed-off-by: Ferran Balaguer <ferran.balaguer@arm.com>
Change-Id: I67a393c1556f0b3022436e41f82f2bf1ab3a1d40
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp
index 13c9327..9673a74 100644
--- a/1.0/HalPolicy.cpp
+++ b/1.0/HalPolicy.cpp
@@ -874,50 +874,41 @@
     const armnn::TensorInfo& outputInfo         = GetTensorInfoForOperand(*output);
 
     // Basic parameters
-    const armnn::TensorInfo& inputToForgetWeights = params.m_InputToForgetWeights->GetInfo();
-    const armnn::TensorInfo& inputToCellWeights   = params.m_InputToCellWeights->GetInfo();
-    const armnn::TensorInfo& inputToOutputWeights = params.m_InputToOutputWeights->GetInfo();
-    const armnn::TensorInfo& recurrentToForgetWeights = params.m_RecurrentToForgetWeights->GetInfo();
-    const armnn::TensorInfo& recurrentToCellWeights = params.m_RecurrentToCellWeights->GetInfo();
-    const armnn::TensorInfo& recurrentToOutputWeights = params.m_RecurrentToOutputWeights->GetInfo();
-    const armnn::TensorInfo& forgetGateBias = params.m_ForgetGateBias->GetInfo();
-    const armnn::TensorInfo& cellBias = params.m_CellBias->GetInfo();
-    const armnn::TensorInfo& outputGateBias = params.m_OutputGateBias->GetInfo();
-
-    //Optional parameters
-    const armnn::TensorInfo* inputToInputWeights = nullptr;
-    const armnn::TensorInfo* recurrentToInputWeights = nullptr;
-    const armnn::TensorInfo* cellToInputWeights = nullptr;
-    const armnn::TensorInfo* inputGateBias = nullptr;
-    const armnn::TensorInfo* projectionWeights = nullptr;
-    const armnn::TensorInfo* projectionBias    = nullptr;
-    const armnn::TensorInfo* cellToForgetWeights = nullptr;
-    const armnn::TensorInfo* cellToOutputWeights = nullptr;
+    armnn::LstmInputParamsInfo paramsInfo;
+    paramsInfo.m_InputToForgetWeights     = &(params.m_InputToForgetWeights->GetInfo());
+    paramsInfo.m_InputToCellWeights       = &(params.m_InputToCellWeights->GetInfo());
+    paramsInfo.m_InputToOutputWeights     = &(params.m_InputToOutputWeights->GetInfo());
+    paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
+    paramsInfo.m_RecurrentToCellWeights   = &(params.m_RecurrentToCellWeights->GetInfo());
+    paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
+    paramsInfo.m_ForgetGateBias           = &(params.m_ForgetGateBias->GetInfo());
+    paramsInfo.m_CellBias                 = &(params.m_CellBias->GetInfo());
+    paramsInfo.m_OutputGateBias           = &(params.m_OutputGateBias->GetInfo());
 
     if(!desc.m_CifgEnabled)
     {
-        inputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
-        recurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
+        paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
+        paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
         if (params.m_CellToInputWeights != nullptr)
         {
-            cellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
+            paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
         }
-        inputGateBias = &(params.m_InputGateBias->GetInfo());
+        paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
     }
 
     if(desc.m_ProjectionEnabled)
     {
-        projectionWeights = &(params.m_ProjectionWeights->GetInfo());
+        paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
         if (params.m_ProjectionBias != nullptr)
         {
-            projectionBias = &(params.m_ProjectionBias->GetInfo());
+            paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
         }
     }
 
     if(desc.m_PeepholeEnabled)
     {
-        cellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
-        cellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
+        paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
+        paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
     }
 
     if (!IsLayerSupportedForAnyBackend(__func__,
@@ -931,23 +922,7 @@
                                        cellStateOutInfo,
                                        outputInfo,
                                        desc,
-                                       inputToForgetWeights,
-                                       inputToCellWeights,
-                                       inputToOutputWeights,
-                                       recurrentToForgetWeights,
-                                       recurrentToCellWeights,
-                                       recurrentToOutputWeights,
-                                       forgetGateBias,
-                                       cellBias,
-                                       outputGateBias,
-                                       inputToInputWeights,
-                                       recurrentToInputWeights,
-                                       cellToInputWeights,
-                                       inputGateBias,
-                                       projectionWeights,
-                                       projectionBias,
-                                       cellToForgetWeights,
-                                       cellToOutputWeights))
+                                       paramsInfo))
     {
         return false;
     }