IVGCVSW-3397 Join LSTM parameter infos in a struct for IsLstmSupported

!android-nn-driver:1461
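
This replaces the 21 individual weight and bias TensorInfo arguments of
IsLstmSupported with a single LstmInputParamsInfo argument. A minimal
call-site sketch, assuming the struct exposes const TensorInfo* members
mirroring the get_* accessors used below (only the new signature and the
m_ProjectionBias member appear in this patch; the other member names, the
*Info variables and the layerSupport object are illustrative):

    LstmInputParamsInfo paramsInfo;
    // Mandatory gate parameters (assumed member names mirroring the accessors).
    paramsInfo.m_InputToForgetWeights     = &inputToForgetWeightsInfo;
    paramsInfo.m_InputToCellWeights       = &inputToCellWeightsInfo;
    paramsInfo.m_InputToOutputWeights     = &inputToOutputWeightsInfo;
    paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeightsInfo;
    paramsInfo.m_RecurrentToCellWeights   = &recurrentToCellWeightsInfo;
    paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeightsInfo;
    paramsInfo.m_ForgetGateBias           = &forgetGateBiasInfo;
    paramsInfo.m_CellBias                 = &cellBiasInfo;
    paramsInfo.m_OutputGateBias           = &outputGateBiasInfo;
    // Optional CIFG, peephole, projection and layer-norm infos are set the
    // same way when the corresponding LstmDescriptor flags are enabled.

    std::string reason;
    bool supported = layerSupport.IsLstmSupported(input, outputStateIn,
                                                  cellStateIn, scratchBuffer,
                                                  outputStateOut, cellStateOut,
                                                  output, descriptor, paramsInfo,
                                                  Optional<std::string&>(reason));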

Change-Id: I9d8fe7adf13832ed0cbcfe98b2353c2f37011d22
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index ac7f310..59c14c4 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -924,51 +924,11 @@
                                       const TensorInfo& cellStateOut,
                                       const TensorInfo& output,
                                       const LstmDescriptor& descriptor,
-                                      const TensorInfo& inputToForgetWeights,
-                                      const TensorInfo& inputToCellWeights,
-                                      const TensorInfo& inputToOutputWeights,
-                                      const TensorInfo& recurrentToForgetWeights,
-                                      const TensorInfo& recurrentToCellWeights,
-                                      const TensorInfo& recurrentToOutputWeights,
-                                      const TensorInfo& forgetGateBias,
-                                      const TensorInfo& cellBias,
-                                      const TensorInfo& outputGateBias,
-                                      const TensorInfo* inputToInputWeights,
-                                      const TensorInfo* recurrentToInputWeights,
-                                      const TensorInfo* cellToInputWeights,
-                                      const TensorInfo* inputGateBias,
-                                      const TensorInfo* projectionWeights,
-                                      const TensorInfo* projectionBias,
-                                      const TensorInfo* cellToForgetWeights,
-                                      const TensorInfo* cellToOutputWeights,
-                                      Optional<std::string&> reasonIfUnsupported,
-                                      const TensorInfo* inputLayerNormWeights,
-                                      const TensorInfo* forgetLayerNormWeights,
-                                      const TensorInfo* cellLayerNormWeights,
-                                      const TensorInfo* outputLayerNormWeights) const
+                                      const LstmInputParamsInfo& paramsInfo,
+                                      Optional<std::string&> reasonIfUnsupported) const
 {
     ignore_unused(descriptor);
-    ignore_unused(inputToForgetWeights);
-    ignore_unused(inputToCellWeights);
-    ignore_unused(inputToOutputWeights);
-    ignore_unused(recurrentToForgetWeights);
-    ignore_unused(recurrentToCellWeights);
-    ignore_unused(recurrentToOutputWeights);
-    ignore_unused(forgetGateBias);
-    ignore_unused(cellBias);
-    ignore_unused(outputGateBias);
-    ignore_unused(inputToInputWeights);
-    ignore_unused(recurrentToInputWeights);
-    ignore_unused(cellToInputWeights);
-    ignore_unused(inputGateBias);
-    ignore_unused(projectionWeights);
-    ignore_unused(projectionBias);
-    ignore_unused(cellToForgetWeights);
-    ignore_unused(cellToOutputWeights);
-    ignore_unused(inputLayerNormWeights);
-    ignore_unused(forgetLayerNormWeights);
-    ignore_unused(cellLayerNormWeights);
-    ignore_unused(outputLayerNormWeights);
+    ignore_unused(paramsInfo);
 
     bool supported = true;
 
@@ -977,26 +937,91 @@
         DataType::QuantisedSymm16
     };
 
+    // check inputs and outputs
     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                   "Reference Lstm: input is not a supported type.");
-
     supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                   "Reference Lstm: input and outputStateIn types are mismatched");
-
     supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                   "Reference Lstm: input and cellStateIn types are mismatched");
-
     supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                   "Reference Lstm: input and scratchBuffer types are mismatched");
-
     supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                   "Reference Lstm: input and outputStateOut types are mismatched");
-
     supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                   "Reference Lstm: input and cellStateOut types are mismatched");
-
     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                   "Reference Lstm: input and output types are mismatched");
+    // check layer parameters
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToForgetWeights()), reasonIfUnsupported,
+                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToCellWeights()), reasonIfUnsupported,
+                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToOutputWeights()), reasonIfUnsupported,
+                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToForgetWeights()), reasonIfUnsupported,
+                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToCellWeights()), reasonIfUnsupported,
+                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToOutputWeights()), reasonIfUnsupported,
+                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetGateBias()), reasonIfUnsupported,
+                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellBias()), reasonIfUnsupported,
+                                  "Reference Lstm: input and CellBias types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputGateBias()), reasonIfUnsupported,
+                                  "Reference Lstm: input and OutputGateBias types are mismatched");
+    if (!descriptor.m_CifgEnabled)
+    {
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToInputWeights()), reasonIfUnsupported,
+                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToInputWeights()),
+                                      reasonIfUnsupported,
+                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputGateBias()), reasonIfUnsupported,
+                                      "Reference Lstm: input and InputGateBias types are mismatched");
+        if (descriptor.m_PeepholeEnabled)
+        {
+            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToInputWeights()),
+                                          reasonIfUnsupported,
+                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
+        }
+    }
+    if (descriptor.m_PeepholeEnabled)
+    {
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToForgetWeights()), reasonIfUnsupported,
+                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToOutputWeights()), reasonIfUnsupported,
+                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
+    }
+    if (descriptor.m_ProjectionEnabled)
+    {
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionWeights()), reasonIfUnsupported,
+                                      "Reference Lstm: input and ProjectionWeights types are mismatched");
+        if (paramsInfo.m_ProjectionBias != nullptr)
+        {
+            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionBias()), reasonIfUnsupported,
+                                          "Reference Lstm: input and ProjectionBias types are mismatched");
+        }
+    }
+    if (descriptor.m_LayerNormEnabled)
+    {
+        if (!descriptor.m_CifgEnabled)
+        {
+            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputLayerNormWeights()),
+                                          reasonIfUnsupported,
+                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
+        }
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetLayerNormWeights()),
+                                      reasonIfUnsupported,
+                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellLayerNormWeights()),
+                                      reasonIfUnsupported,
+                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputLayerNormWeights()),
+                                      reasonIfUnsupported,
+                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
+    }
 
     return supported;
 }