IVGCVSW-3397 Join LSTM parameter infos in a struct for IsLstmSupported

!android-nn-driver:1461

Change-Id: I9d8fe7adf13832ed0cbcfe98b2353c2f37011d22
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
index f4d8974..3dbbbc3 100644
--- a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
@@ -224,22 +224,7 @@
                                                 const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
                                                 const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
                                                 const TensorInfo& output, const LstmDescriptor& descriptor,
-                                                const TensorInfo& inputToForgetWeights,
-                                                const TensorInfo& inputToCellWeights,
-                                                const TensorInfo& inputToOutputWeights,
-                                                const TensorInfo& recurrentToForgetWeights,
-                                                const TensorInfo& recurrentToCellWeights,
-                                                const TensorInfo& recurrentToOutputWeights,
-                                                const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
-                                                const TensorInfo& outputGateBias,
-                                                const TensorInfo* inputToInputWeights,
-                                                const TensorInfo* recurrentToInputWeights,
-                                                const TensorInfo* cellToInputWeights,
-                                                const TensorInfo* inputGateBias,
-                                                const TensorInfo* projectionWeights,
-                                                const TensorInfo* projectionBias,
-                                                const TensorInfo* cellToForgetWeights,
-                                                const TensorInfo* cellToOutputWeights)
+                                                const LstmInputParamsInfo& paramsInfo)
 {
     arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
 
@@ -253,18 +238,21 @@
     const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
 
     // Basic parameters
-    const arm_compute::TensorInfo aclInputToForgetWeightsInfo = BuildArmComputeTensorInfo(inputToForgetWeights);
-    const arm_compute::TensorInfo aclInputToCellWeightsInfo = BuildArmComputeTensorInfo(inputToCellWeights);
-    const arm_compute::TensorInfo aclInputToOutputWeightsInfo = BuildArmComputeTensorInfo(inputToOutputWeights);
+    const arm_compute::TensorInfo aclInputToForgetWeightsInfo
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights());
+    const arm_compute::TensorInfo aclInputToCellWeightsInfo
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights());
+    const arm_compute::TensorInfo aclInputToOutputWeightsInfo
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights());
     const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
-                                  = BuildArmComputeTensorInfo(recurrentToForgetWeights);
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights());
     const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
-                                  = BuildArmComputeTensorInfo(recurrentToCellWeights);
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights());
     const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
-                                  = BuildArmComputeTensorInfo(recurrentToOutputWeights);
-    const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(forgetGateBias);
-    const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(cellBias);
-    const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(outputGateBias);
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights());
+    const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias());
+    const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellBias());
+    const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias());
 
     arm_compute::TensorInfo aclInputToInputWeightsInfo;
     arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
@@ -277,43 +265,37 @@
 
     if (!descriptor.m_CifgEnabled)
     {
-        armnn::TensorInfo inputToInputWInfo = *inputToInputWeights;
-        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(inputToInputWInfo);
-        armnn::TensorInfo recurrentToInputWInfo = *recurrentToInputWeights;
-        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(recurrentToInputWInfo);
+        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights());
+        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights());
 
-        if (cellToInputWeights != nullptr)
+        if (paramsInfo.m_CellToInputWeights != nullptr)
         {
-            armnn::TensorInfo cellToInputWInfo = *cellToInputWeights;
-            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(cellToInputWInfo);
+            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToInputWeights());
         }
-        armnn::TensorInfo inputGateBiasInfo = *inputGateBias;
-        aclInputGateBiasInfo = BuildArmComputeTensorInfo(inputGateBiasInfo);
+        aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
         lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo,
-                                         cellToInputWeights != nullptr ? &aclCellToInputWeightsInfo: nullptr,
+                                         paramsInfo.m_CellToInputWeights != nullptr ?
+                                         &aclCellToInputWeightsInfo: nullptr,
                                          &aclInputGateBiasInfo);
     }
 
     if (descriptor.m_ProjectionEnabled)
     {
-        const armnn::TensorInfo& projectionWInfo = *projectionWeights;
-        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(projectionWInfo);
+        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionWeights());
 
-        if (projectionBias != nullptr)
+        if (paramsInfo.m_ProjectionBias != nullptr)
         {
-            const armnn::TensorInfo& projectionBiasInfo = *projectionBias;
-            aclProjectionBiasInfo = BuildArmComputeTensorInfo(projectionBiasInfo);
+            aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionBias());
         }
         lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
-                                               projectionBias != nullptr ? &aclProjectionBiasInfo: nullptr);
+                                               paramsInfo.m_ProjectionBias != nullptr ?
+                                               &aclProjectionBiasInfo: nullptr);
     }
 
     if (descriptor.m_PeepholeEnabled)
     {
-        const armnn::TensorInfo& cellToForgetWInfo = *cellToForgetWeights;
-        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(cellToForgetWInfo);
-        const armnn::TensorInfo& cellToOutputWInfo = *cellToOutputWeights;
-        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(cellToOutputWInfo);
+        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToForgetWeights());
+        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToOutputWeights());
         lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
     }