COMPMID-1188: Fix LSTM IO dimension requirements by accepting lower-rank tensors.

Change-Id: Iee92ccce6422368c19173174e6f58e7aada12233
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140143
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp
index 86e5eb9..8723251 100644
--- a/src/runtime/CL/functions/CLLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLLSTMLayer.cpp
@@ -295,27 +295,27 @@
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
                                                        recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, output_state, cell_state);
-    ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() != 2);
-    ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() != 2);
-    ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() != 2);
-    ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
-    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() != 2);
-    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() != 2);
-    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
-    ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
-    ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() != 1);
-    ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() != 1);
-    ARM_COMPUTE_RETURN_ERROR_ON(output_state->num_dimensions() != 2);
-    ARM_COMPUTE_RETURN_ERROR_ON(cell_state->num_dimensions() != 2);
-    ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() != 2);
-    ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() != 2);
+    ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
+    ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2);
+    ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2);
+    ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2);
+    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2);
+    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2);
+    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2);
+    ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1);
+    ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1);
+    ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1);
+    ARM_COMPUTE_RETURN_ERROR_ON(output_state->num_dimensions() > 2);
+    ARM_COMPUTE_RETURN_ERROR_ON(cell_state->num_dimensions() > 2);
+    ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2);
+    ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2);
     ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0) && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
 
     if(lstm_params.has_peephole_opt())
     {
         ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
-        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
-        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() != 1);
+        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1);
+        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1);
     }
 
     TensorShape      units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights);
@@ -340,10 +340,10 @@
     if(!lstm_params.has_cifg_opt())
     {
         ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.cell_to_input_weights(), lstm_params.input_gate_bias());
-        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() != 2);
-        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() != 2);
-        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() != 1);
-        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() != 1);
+        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2);
+        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
+        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
+        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);
         ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), cell_state, true, false));
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(cell_state, &num_units_transposed_info, nullptr, &gemmv_shape_info, 1.f, 0.f, GEMMInfo()));
         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(cell_state, &gemmv_shape_info, cell_state, ConvertPolicy::SATURATE));