COMPMID-3470: Modify NE/CLQLSTMLayer interface to provide 3 outputs

Change-Id: I895b697c89c9a7509d48a54ac1effb7fbd8cca19
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3174
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
diff --git a/src/runtime/CL/functions/CLQLSTMLayer.cpp b/src/runtime/CL/functions/CLQLSTMLayer.cpp
index d9b5c7c..a20ffc6 100644
--- a/src/runtime/CL/functions/CLQLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLQLSTMLayer.cpp
@@ -76,12 +76,12 @@
                              const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                              const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
                              const ICLTensor *cell_state_in, const ICLTensor *output_state_in,
-                             ICLTensor *cell_state_out, ICLTensor *output_state_out,
+                             ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
                              const LSTMParams<ICLTensor> &lstm_params)
 {
     configure(CLKernelLibrary::get().get_compile_context(), input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
               recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias,
-              cell_state_in, output_state_in, cell_state_out, output_state_out, lstm_params);
+              cell_state_in, output_state_in, cell_state_out, output_state_out, output, lstm_params);
 }
 
 void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input,
@@ -89,12 +89,13 @@
                              const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                              const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
                              const ICLTensor *cell_state_in, const ICLTensor *output_state_in,
-                             ICLTensor *cell_state_out, ICLTensor *output_state_out,
+                             ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
                              const LSTMParams<ICLTensor> &lstm_params)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                  recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
-                                 forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
+                                 forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
+                                 cell_state_out, output_state_out, output);
 
     // Set lstm parameters
     LSTMParams<ITensorInfo> lstm_params_info{};
@@ -104,7 +105,8 @@
     ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
                                                       recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
                                                       forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
-                                                      cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), lstm_params_info));
+                                                      cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
+                                                      lstm_params_info));
 
     const int batch_size = input->info()->dimension(1);
     const int num_units  = input_to_output_weights->info()->dimension(1);
@@ -446,6 +448,9 @@
             _has_projection_clipping = true;
         }
     }
+
+    // Copy output_state_out to output
+    _copy_output.configure(compile_context, output_state_out, output);
 }
 
 Status CLQLSTMLayer::validate(const ITensorInfo *input,
@@ -453,11 +458,12 @@
                               const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
                               const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
                               const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
-                              const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out,
+                              const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
                               const LSTMParams<ITensorInfo> &lstm_params)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
-                                        recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
+                                        recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
+                                        cell_state_out, output_state_out, output);
 
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED);
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");
@@ -768,6 +774,7 @@
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
     }
 
+    ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(output_state_out, output));
     return Status{};
 }
 
@@ -887,6 +894,9 @@
             _projection_clip.run();
         }
     }
+
+    // Copy output_state_out to output
+    CLScheduler::get().enqueue(_copy_output);
 }
 
 void CLQLSTMLayer::prepare()