MLCE-530 Add support for UnidirectionalSequenceLstm to RefWorkload

 * Add implementation of IsUnidirectionalSequenceLstmSupported to RefLayerSupport
 * Add RefUnidirectionalSequenceLstmWorkload
 * Refactor the Lstm implementation so it can be used by both Lstm and UnidirectionalSequenceLstm
 * Add unit tests

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ibc066d213213a11b955dfefbe518de643298ba0c
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 1b05c4e..2603371 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1242,6 +1242,7 @@
                                   "Reference Lstm: input and outputStateOut types are mismatched");
     supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                   "Reference Lstm: input and cellStateOut types are mismatched");
+
     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                   "Reference Lstm: input and output types are mismatched");
     // check layer parameters
@@ -2288,4 +2289,157 @@
     return supported;
 }
 
+// Reports whether the reference backend supports a UnidirectionalSequenceLstm layer with the given
+// tensor infos, descriptor and LSTM parameter infos.
+//
+// Only Float32 inputs, outputs, weights and biases are accepted. The optional hiddenStateOutput and
+// cellStateOutput tensors are not supported at the moment, so providing either one makes the layer
+// unsupported. Reasons for rejection are reported through reasonIfUnsupported.
+bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
+        const TensorInfo& input,
+        const TensorInfo& outputStateIn,
+        const TensorInfo& cellStateIn,
+        const TensorInfo& output,
+        const Optional<TensorInfo>& hiddenStateOutput,
+        const Optional<TensorInfo>& cellStateOutput,
+        const UnidirectionalSequenceLstmDescriptor& descriptor,
+        const LstmInputParamsInfo& paramsInfo,
+        Optional<std::string&> reasonIfUnsupported) const
+{
+    bool supported = true;
+
+    // Separate hidden/cell state outputs are not supported yet: mark the layer unsupported rather than
+    // only recording a reason, and guard the optional reason string before writing to it.
+    if (hiddenStateOutput.has_value() || cellStateOutput.has_value())
+    {
+        supported = false;
+        if (reasonIfUnsupported.has_value())
+        {
+            reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output "
+                                           "and cell state output are not supported at the moment.\n";
+        }
+    }
+
+    std::array<DataType, 1> supportedTypes =
+    {
+        DataType::Float32
+    };
+
+    std::array<DataType, 1> supportedWeightTypes =
+    {
+        DataType::Float32
+    };
+
+    // check inputs and outputs
+    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+                                  "Reference UnidirectionalSequenceLstm: input is not a supported type.");
+    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
+                                  "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
+                                  "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+                                  "Reference UnidirectionalSequenceLstm: input and output types are mismatched");
+    // check layer parameters
+    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
+                                  reasonIfUnsupported,
+                                  "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
+                                  "is not a supported type.");
+    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
+                                  reasonIfUnsupported,
+                                  "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
+    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
+                                  reasonIfUnsupported,
+                                  "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
+                                  "is not a supported type.");
+    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
+                                  reasonIfUnsupported,
+                                  "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
+                                  "is not a supported type.");
+    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
+                                  reasonIfUnsupported,
+                                  "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
+                                  "is not a supported type.");
+    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
+                                  reasonIfUnsupported,
+                                  "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
+                                  "is not a supported type.");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
+                                  "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types "
+                                  "are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
+                                  "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
+                                  "Reference UnidirectionalSequenceLstm: input and OutputGateBias types "
+                                  "are mismatched");
+    if (!descriptor.m_CifgEnabled)
+    {
+        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
+                                      reasonIfUnsupported,
+                                      "Reference UnidirectionalSequenceLstm: InputToInputWeights "
+                                      "is not a supported type.");
+        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
+                                      reasonIfUnsupported,
+                                      "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
+                                      "is not a supported type.");
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
+                                      "Reference UnidirectionalSequenceLstm: input and InputGateBias types "
+                                      "are mismatched");
+        if (descriptor.m_PeepholeEnabled)
+        {
+            supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
+                                          reasonIfUnsupported,
+                                          "Reference UnidirectionalSequenceLstm: CellToInputWeights "
+                                          "is not a supported type.");
+        }
+    }
+    if (descriptor.m_PeepholeEnabled)
+    {
+        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
+                                      reasonIfUnsupported,
+                                      "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
+                                      "is not a supported type.");
+        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
+                                      reasonIfUnsupported,
+                                      "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
+                                      "is not a supported type.");
+    }
+    if (descriptor.m_ProjectionEnabled)
+    {
+        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
+                                      reasonIfUnsupported,
+                                      "Reference UnidirectionalSequenceLstm: ProjectionWeights "
+                                      "is not a supported type.");
+        if (paramsInfo.m_ProjectionBias != nullptr)
+        {
+            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
+                                          "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
+                                          "are mismatched");
+        }
+    }
+    if (descriptor.m_LayerNormEnabled)
+    {
+        if (!descriptor.m_CifgEnabled)
+        {
+            supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
+                                          reasonIfUnsupported,
+                                          "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
+                                          "is not a supported type.");
+        }
+        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
+                                      reasonIfUnsupported,
+                                      "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
+                                      "is not a supported type.");
+        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
+                                      reasonIfUnsupported,
+                                      "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
+                                      "is not a supported type.");
+        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
+                                      reasonIfUnsupported,
+                                      "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
+                                      "is not a supported type.");
+    }
+
+    return supported;
+}
+
 } // namespace armnn