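Fix incorrect validation of Unidirectional Sequence LSTM on Cl and Neon

The validate functions now take the actual outputStateOut and cellStateOut
tensor infos instead of ignoring the optional state outputs. They reject
inputs that are not 3D. For batch-major input, the permute is validated
against the permuted ({ 1, 0, 2 }) input shape rather than an unpermuted
copy of the input. The scratch buffer is sized as
{ batchSize, numUnits * 4 } ({ batchSize, numUnits * 3 } with CIFG) instead
of reusing the cell state shape.
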
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I54c60fb98b9c560c300572f46d42b13aec7e402e
diff --git a/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.cpp b/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.cpp
index 289442e..fb31d7c 100644
--- a/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -508,17 +508,21 @@
 ClUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
                                                   const TensorInfo& outputStateIn,
                                                   const TensorInfo& cellStateIn,
+                                                  const TensorInfo& outputStateOut,
+                                                  const TensorInfo& cellStateOut,
                                                   const TensorInfo& output,
-                                                  const Optional<TensorInfo>& hiddenStateOutput,
-                                                  const Optional<TensorInfo>& cellStateOutput,
                                                   const UnidirectionalSequenceLstmDescriptor& descriptor,
                                                   const LstmInputParamsInfo& paramsInfo)
 {
-    IgnoreUnused(hiddenStateOutput, cellStateOutput);
-
     TensorShape inputLayerShape  = input.GetShape();
     TensorShape outputLayerShape = output.GetShape();
 
+    if (inputLayerShape.GetNumDimensions() != 3)
+    {
+        return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
+                                   "Unidirectional Sequence LSTM layer validate status failed.");
+    }
+
     unsigned int maxTime    = descriptor.m_TimeMajor?inputLayerShape[0]:inputLayerShape[1];
     unsigned int batchSize  = descriptor.m_TimeMajor?inputLayerShape[1]:inputLayerShape[0];
     unsigned int inputSize  = inputLayerShape[2];
@@ -544,7 +548,7 @@
     //
     // Permute validate
     //
-    TensorInfo              permuteOutInfo    = TensorInfo(input);
+    TensorInfo              permuteOutInfo    = armnnUtils::Permuted(input, { 1U, 0U, 2U });
     arm_compute::TensorInfo aclPermuteOutInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permuteOutInfo);
     if (!descriptor.m_TimeMajor)
     {
@@ -610,9 +614,16 @@
 
     arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
 
-    const TensorInfo& scratchBuffer = TensorInfo(cellStateIn.GetShape(), input.GetDataType());
-    const TensorInfo& outputStateOut = TensorInfo(outputStateIn.GetShape(), input.GetDataType());
-    const TensorInfo& cellStateOut = TensorInfo(cellStateIn.GetShape(), input.GetDataType());
+    unsigned int numUnits = cellStateIn.GetShape()[1];
+    unsigned int scratchBufferFactor = 4;
+
+    if (descriptor.m_CifgEnabled)
+    {
+        // scratchBuffer = { batchSize, numUnits * 3 } with CIFG
+       scratchBufferFactor = 3;
+    }
+
+    const TensorInfo& scratchBuffer = TensorInfo({ batchSize, numUnits * scratchBufferFactor }, input.GetDataType());
 
     // The inputs and outputs
     const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);