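COMPMID-1124 : Fixes in CLLSTM layer

In the LSTM validation fixture, compute the cell_state contributions to the
forget, input and output gates with a pixel-wise multiplication of cell_state
by the cell_to_forget_w / cell_to_input_w / cell_to_output_w weights, instead
of a transpose of those weights followed by a GEMM.
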
Change-Id: Ifc8e12c296d3ef2bf8e0f0bf1b87b7fd47a1fad7
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139248
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Ruomei Yan <ruomei.yan@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h
index bff2f37..3c1a560 100644
--- a/tests/validation/fixtures/LSTMLayerFixture.h
+++ b/tests/validation/fixtures/LSTMLayerFixture.h
@@ -315,9 +315,8 @@
 
         if(peephole_opt)
         {
-            transposed_weights = reference::transpose(cell_to_forget_w);
-            gemm               = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f);
-            forget_gate        = reference::arithmetic_addition(forget_gate, gemm, data_type, ConvertPolicy::SATURATE);
+            SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication(cell_state, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+            forget_gate                               = reference::arithmetic_addition(forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE);
         }
 
         forget_gate = reference::activation_layer(forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
@@ -332,14 +331,13 @@
         }
         else
         {
-            SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape);
-            transposed_weights                    = reference::transpose(recurrent_to_input_w);
-            gemm                                  = reference::gemm(output_state, transposed_weights, cell_state, 1.f, 0.f);
-            input_gate                            = reference::arithmetic_addition(fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE);
-            transposed_weights                    = reference::transpose(cell_to_input_w);
-            gemm                                  = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f);
-            input_gate                            = reference::arithmetic_addition(input_gate, gemm, data_type, ConvertPolicy::SATURATE);
-            input_gate                            = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+            SimpleTensor<T> fully_connected_input    = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape);
+            transposed_weights                       = reference::transpose(recurrent_to_input_w);
+            gemm                                     = reference::gemm(output_state, transposed_weights, cell_state, 1.f, 0.f);
+            input_gate                               = reference::arithmetic_addition(fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE);
+            SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication(cell_state, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+            input_gate                               = reference::arithmetic_addition(input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE);
+            input_gate                               = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
         }
 
         // Compute cell_state
@@ -363,9 +361,8 @@
         output                                 = reference::arithmetic_addition(fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE);
         if(peephole_opt)
         {
-            transposed_weights = reference::transpose(cell_to_output_w);
-            gemm               = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f);
-            output             = reference::arithmetic_addition(output, gemm, data_type, ConvertPolicy::SATURATE);
+            pixelwise_mul = reference::pixel_wise_multiplication(cell_state, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+            output        = reference::arithmetic_addition(output, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
         }
         output = reference::activation_layer(output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));