COMPMID-3385: Async support to CLArithmetic* kernels/functions Pt.2

Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Change-Id: Idc5ac2dd2ba5295c00c88b44a783645327a27e15
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3617
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
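
This patch ports the pixel-wise multiplication kernels to the stateless
operator model introduced in Pt.1: kernels are configured from ITensorInfo
objects only, and the concrete tensors are supplied at execution time through
tensor maps. A minimal sketch of the resulting contract, assuming the Pt.1
experimental types (InputTensorMap, OutputTensorMap, TensorType) and an
initialised CLScheduler:

    // Configure against metadata only; no backing CL memory is needed yet.
    CLTensor src0, src1, dst;
    const TensorInfo info(TensorShape(16U, 16U), 1, DataType::F32);
    src0.allocator()->init(info);
    src1.allocator()->init(info);
    dst.allocator()->init(info);

    CLPixelWiseMultiplicationKernel kernel;
    kernel.configure(CLKernelLibrary::get().get_compile_context(),
                     src0.info(), src1.info(), dst.info(),
                     1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);

    // Bind the concrete tensors per invocation.
    src0.allocator()->allocate();
    src1.allocator()->allocate();
    dst.allocator()->allocate();
    const InputTensorMap  inputs{ { TensorType::ACL_SRC_0, &src0 }, { TensorType::ACL_SRC_1, &src1 } };
    const OutputTensorMap outputs{ { TensorType::ACL_DST, &dst } };
    kernel.run_op(inputs, outputs, kernel.window(), CLScheduler::get().queue());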
diff --git a/arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h b/arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h
index bb98eb8..86159fc 100644
--- a/arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h
+++ b/arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h
@@ -62,16 +62,16 @@
      *   - (QSYMM16,QSYMM16)               -> QSYMM16
      *   - (QSYMM16,QSYMM16)               -> S32
      *
-     * @param[in]  input1          An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
-     * @param[in]  input2          An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
-     * @param[out] output          The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+     * @param[in]  input1          An input tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+     * @param[in]  input2          An input tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+     * @param[out] output          The output tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
      * @param[in]  scale           Scale to apply after multiplication.
      *                             Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
      * @param[in]  overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate
      * @param[in]  rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even.
      * @param[in]  act_info        (Optional) Activation layer information in case of a fused activation.
      */
-    void configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float scale,
+    void configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, float scale,
                    ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info = ActivationLayerInfo());
     /** Initialise the kernel's input, output and border mode.
      *
@@ -90,16 +90,16 @@
      *   - (QSYMM16,QSYMM16)               -> S32
      *
      * @param[in]  compile_context The compile context to be used.
-     * @param[in]  input1          An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
-     * @param[in]  input2          An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
-     * @param[out] output          The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+     * @param[in]  input1          An input tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+     * @param[in]  input2          An input tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+     * @param[out] output          The output tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
      * @param[in]  scale           Scale to apply after multiplication.
      *                             Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
      * @param[in]  overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate
      * @param[in]  rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even.
      * @param[in]  act_info        (Optional) Activation layer information in case of a fused activation.
      */
-    void configure(const CLCompileContext &compile_context, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float scale,
+    void configure(const CLCompileContext &compile_context, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, float scale,
                    ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info = ActivationLayerInfo());
     /** Static function to check if given info will lead to a valid configuration of @ref CLPixelWiseMultiplicationKernel
      *
@@ -132,13 +132,13 @@
                            ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info = ActivationLayerInfo());
 
     // Inherited methods overridden:
-    void run(const Window &window, cl::CommandQueue &queue) override;
+    void run_op(const InputTensorMap &inputs, const OutputTensorMap &outputs, const Window &window, cl::CommandQueue &queue) override;
     BorderSize border_size() const override;
 
 private:
-    const ICLTensor *_input1;
-    const ICLTensor *_input2;
-    ICLTensor       *_output;
+    const ITensorInfo *_input1;
+    const ITensorInfo *_input2;
+    ITensorInfo       *_output;
 };
 
 /** Interface for the complex pixelwise multiplication kernel. */
@@ -157,21 +157,21 @@
     CLComplexPixelWiseMultiplicationKernel &operator=(CLComplexPixelWiseMultiplicationKernel &&) = default;
     /** Initialise the kernel's input, output and border mode.
      *
-     * @param[in]  input1   An input tensor. Data types supported: F32. Number of channels supported: 2.
-     * @param[in]  input2   An input tensor. Data types supported: same as @p input1. Number of channels supported: same as @p input1.
-     * @param[out] output   The output tensor, Data types supported: same as @p input1. Number of channels supported: same as @p input1.
+     * @param[in]  input1   An input tensor info. Data types supported: F32. Number of channels supported: 2.
+     * @param[in]  input2   An input tensor info. Data types supported: same as @p input1. Number of channels supported: same as @p input1.
+     * @param[out] output   The output tensor info. Data types supported: same as @p input1. Number of channels supported: same as @p input1.
      * @param[in]  act_info (Optional) Activation layer information in case of a fused activation.
      */
-    void configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, const ActivationLayerInfo &act_info = ActivationLayerInfo());
+    void configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo());
     /** Initialise the kernel's input, output and border mode.
      *
      * @param[in]  compile_context The compile context to be used.
-     * @param[in]  input1          An input tensor. Data types supported: F32. Number of channels supported: 2.
-     * @param[in]  input2          An input tensor. Data types supported: same as @p input1. Number of channels supported: same as @p input1.
-     * @param[out] output          The output tensor, Data types supported: same as @p input1. Number of channels supported: same as @p input1.
+     * @param[in]  input1          An input tensor info. Data types supported: F32. Number of channels supported: 2.
+     * @param[in]  input2          An input tensor info. Data types supported: same as @p input1. Number of channels supported: same as @p input1.
+     * @param[out] output          The output tensor info. Data types supported: same as @p input1. Number of channels supported: same as @p input1.
      * @param[in]  act_info        (Optional) Activation layer information in case of a fused activation.
      */
-    void configure(const CLCompileContext &compile_context, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, const ActivationLayerInfo &act_info = ActivationLayerInfo());
+    void configure(const CLCompileContext &compile_context, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo());
     /** Static function to check if given info will lead to a valid configuration of @ref CLComplexPixelWiseMultiplicationKernel
      *
      * @param[in] input1   An input tensor info. Data types supported: F32. Number of channels supported: 2.
@@ -184,13 +184,13 @@
     static Status validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo());
 
     // Inherited methods overridden:
-    void run(const Window &window, cl::CommandQueue &queue) override;
+    void run_op(const InputTensorMap &inputs, const OutputTensorMap &outputs, const Window &window, cl::CommandQueue &queue) override;
     BorderSize border_size() const override;
 
 private:
-    const ICLTensor *_input1;
-    const ICLTensor *_input2;
-    ICLTensor       *_output;
+    const ITensorInfo *_input1;
+    const ITensorInfo *_input2;
+    ITensorInfo       *_output;
 };
 } // namespace arm_compute
 #endif /*ARM_COMPUTE_CLPIXELWISEMULTIPLICATIONKERNEL_H */
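
With configure() now taking only metadata, the kernel can be validated and
configured before any tensor is allocated. A hedged example using bare
TensorInfo objects (shapes and data types chosen arbitrarily):

    TensorInfo a(TensorShape(32U, 8U), 1, DataType::F32);
    TensorInfo b(TensorShape(32U, 8U), 1, DataType::F32);
    TensorInfo out(TensorShape(32U, 8U), 1, DataType::F32);

    // Static validation first, then configuration; neither touches device memory.
    ARM_COMPUTE_ERROR_THROW_ON(CLPixelWiseMultiplicationKernel::validate(
        &a, &b, &out, 1.f, ConvertPolicy::WRAP, RoundingPolicy::TO_ZERO));

    CLPixelWiseMultiplicationKernel mul;
    mul.configure(CLKernelLibrary::get().get_compile_context(), &a, &b, &out,
                  1.f, ConvertPolicy::WRAP, RoundingPolicy::TO_ZERO);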
diff --git a/arm_compute/core/utils/misc/InfoHelpers.h b/arm_compute/core/utils/misc/InfoHelpers.h
index ffde82b..ced0d24 100644
--- a/arm_compute/core/utils/misc/InfoHelpers.h
+++ b/arm_compute/core/utils/misc/InfoHelpers.h
@@ -86,7 +86,7 @@
     {
         ARM_COMPUTE_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.input_gate_bias());
 
-        const ITensorInfo *cell_to_input_weights_info = (lstm_params.has_peephole_opt()) ? lstm_params.cell_to_input_weights()->info() : nullptr;
+        ITensorInfo *cell_to_input_weights_info = (lstm_params.has_peephole_opt()) ? lstm_params.cell_to_input_weights()->info() : nullptr;
         lstm_params_info->set_cifg_params(lstm_params.input_to_input_weights()->info(), lstm_params.recurrent_to_input_weights()->info(),
                                           cell_to_input_weights_info, lstm_params.input_gate_bias()->info());
     }
@@ -100,10 +100,10 @@
             ARM_COMPUTE_ERROR_ON_NULLPTR(lstm_params.input_layer_norm_weights());
         }
 
-        const ITensorInfo *forget_info = lstm_params.forget_layer_norm_weights()->info();
-        const ITensorInfo *cell_info   = lstm_params.cell_layer_norm_weights()->info();
-        const ITensorInfo *output_info = lstm_params.output_layer_norm_weights()->info();
-        const ITensorInfo *input_info  = lstm_params.has_cifg_opt() ? nullptr : lstm_params.input_layer_norm_weights()->info();
+        ITensorInfo *forget_info = lstm_params.forget_layer_norm_weights()->info();
+        ITensorInfo *cell_info   = lstm_params.cell_layer_norm_weights()->info();
+        ITensorInfo *output_info = lstm_params.output_layer_norm_weights()->info();
+        ITensorInfo *input_info  = lstm_params.has_cifg_opt() ? nullptr : lstm_params.input_layer_norm_weights()->info();
 
         lstm_params_info->set_layer_normalization_params(input_info, forget_info, cell_info, output_info);
     }
diff --git a/arm_compute/runtime/CL/functions/CLLSTMLayer.h b/arm_compute/runtime/CL/functions/CLLSTMLayer.h
index abfcc3a..1a8b334 100644
--- a/arm_compute/runtime/CL/functions/CLLSTMLayer.h
+++ b/arm_compute/runtime/CL/functions/CLLSTMLayer.h
@@ -28,7 +28,6 @@
 
 #include "arm_compute/core/CL/kernels/CLCopyKernel.h"
 #include "arm_compute/core/CL/kernels/CLMemsetKernel.h"
-#include "arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/runtime/CL/CLTensor.h"
 #include "arm_compute/runtime/CL/functions/CLActivationLayer.h"
@@ -37,6 +36,7 @@
 #include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"
 #include "arm_compute/runtime/CL/functions/CLGEMM.h"
 #include "arm_compute/runtime/CL/functions/CLMeanStdDevNormalizationLayer.h"
+#include "arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h"
 #include "arm_compute/runtime/IMemoryManager.h"
 #include "arm_compute/runtime/MemoryGroup.h"
 #include "arm_compute/runtime/common/LSTMParams.h"
@@ -97,7 +97,7 @@
                    const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                    const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                    const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
-                   const ICLTensor *output_state_in, const ICLTensor *cell_state_in,
+                   const ICLTensor *output_state_in, ICLTensor *cell_state_in,
                    ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
                    const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold = 0.f, float projection_threshold = 0.f);
     /** Initialize function's tensors.
@@ -143,7 +143,7 @@
                    const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                    const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                    const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
-                   const ICLTensor *output_state_in, const ICLTensor *cell_state_in,
+                   const ICLTensor *output_state_in, ICLTensor *cell_state_in,
                    ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
                    const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold = 0.f, float projection_threshold = 0.f);
 
@@ -200,90 +200,90 @@
     void prepare() override;
 
 private:
-    MemoryGroup                     _memory_group;
-    CLFullyConnectedLayer           _fully_connected_input_gate;
-    CLArithmeticAddition            _accum_input_gate1;
-    CLArithmeticSubtraction         _subtract_input_gate;
-    CLPixelWiseMultiplicationKernel _pixelwise_mul_input_gate;
-    CLActivationLayer               _activation_input_gate;
-    CLFullyConnectedLayer           _fully_connected_forget_gate;
-    CLArithmeticAddition            _accum_forget_gate1;
-    CLPixelWiseMultiplicationKernel _pixelwise_mul_forget_gate;
-    CLActivationLayer               _activation_forget_gate;
-    CLFullyConnectedLayer           _fully_connected_cell_state;
-    CLGEMM                          _gemm_cell_state1;
-    CLTransposeKernel               _transpose_cell_state;
-    CLArithmeticAddition            _accum_cell_state1;
-    CLArithmeticAddition            _accum_cell_state2;
-    CLPixelWiseMultiplicationKernel _pixelwise_mul_cell_state1;
-    CLActivationLayer               _activation_cell_state;
-    CLActivationLayer               _cell_clip;
-    CLPixelWiseMultiplicationKernel _pixelwise_mul_cell_state2;
-    CLFullyConnectedLayer           _fully_connected_output;
-    CLPixelWiseMultiplicationKernel _pixelwise_mul_output_state1;
-    CLArithmeticAddition            _accum_output1;
-    CLActivationLayer               _activation_output;
-    CLActivationLayer               _activation_output_state;
-    CLPixelWiseMultiplicationKernel _pixelwise_mul_output_state2;
-    CLFullyConnectedLayer           _fully_connected_output_state;
-    CLActivationLayer               _projection_clip;
-    CLCopyKernel                    _copy_cell_state;
-    CLCopyKernel                    _copy_output;
-    CLConcatenateLayer              _concat_scratch_buffer;
-    CLConcatenateLayer              _concat_inputs_forget_gate;
-    CLConcatenateLayer              _concat_weights_forget_gate;
-    CLConcatenateLayer              _concat_weights_input_gate;
-    CLConcatenateLayer              _concat_weights_output;
-    CLMemsetKernel                  _ones_memset_kernel;
-    CLMeanStdDevNormalizationLayer  _mean_std_norm_input_gate;
-    CLPixelWiseMultiplicationKernel _pixelwise_mul_input_gate_coeff;
-    CLArithmeticAddition            _accum_input_gate_bias;
-    CLMeanStdDevNormalizationLayer  _mean_std_norm_forget_gate;
-    CLPixelWiseMultiplicationKernel _pixelwise_mul_forget_gate_coeff;
-    CLArithmeticAddition            _accum_forget_gate_bias;
-    CLMeanStdDevNormalizationLayer  _mean_std_norm_cell_gate;
-    CLPixelWiseMultiplicationKernel _pixelwise_mul_cell_gate_coeff;
-    CLArithmeticAddition            _accum_cell_gate_bias;
-    CLMeanStdDevNormalizationLayer  _mean_std_norm_output_gate;
-    CLPixelWiseMultiplicationKernel _pixelwise_mul_output_gate_coeff;
-    CLArithmeticAddition            _accum_output_gate_bias;
-    CLTensor                        _input_gate_out1;
-    CLTensor                        _input_gate_out2;
-    CLTensor                        _input_gate_out3;
-    CLTensor                        _input_gate_out4;
-    CLTensor                        _forget_gate_out1;
-    CLTensor                        _forget_gate_out2;
-    CLTensor                        _forget_gate_out3;
-    CLTensor                        _forget_gate_out4;
-    CLTensor                        _forget_gate_out5;
-    CLTensor                        _forget_gate_out6;
-    CLTensor                        _cell_state_out1;
-    CLTensor                        _cell_state_out2;
-    CLTensor                        _cell_state_out3;
-    CLTensor                        _cell_state_out4;
-    CLTensor                        _cell_state_out5;
-    CLTensor                        _output1;
-    CLTensor                        _output2;
-    CLTensor                        _output3;
-    CLTensor                        _output4;
-    CLTensor                        _cell_state_activation;
-    CLTensor                        _output_state1;
-    CLTensor                        _ones;
-    CLTensor                        _input_layer_norm_out1;
-    CLTensor                        _input_layer_norm_out2;
-    CLTensor                        _forget_layer_norm_out1;
-    CLTensor                        _forget_layer_norm_out2;
-    CLTensor                        _cell_layer_norm_out1;
-    CLTensor                        _cell_layer_norm_out2;
-    CLTensor                        _output_layer_norm_out1;
-    CLTensor                        _output_layer_norm_out2;
-    bool                            _run_peephole_opt;
-    bool                            _run_cifg_opt;
-    bool                            _perform_cell_clipping;
-    bool                            _has_projection_weights;
-    bool                            _perform_projection_clipping;
-    bool                            _is_prepared;
-    bool                            _is_layer_norm_lstm;
+    MemoryGroup                    _memory_group;
+    CLFullyConnectedLayer          _fully_connected_input_gate;
+    CLArithmeticAddition           _accum_input_gate1;
+    CLArithmeticSubtraction        _subtract_input_gate;
+    CLPixelWiseMultiplication      _pixelwise_mul_input_gate;
+    CLActivationLayer              _activation_input_gate;
+    CLFullyConnectedLayer          _fully_connected_forget_gate;
+    CLArithmeticAddition           _accum_forget_gate1;
+    CLPixelWiseMultiplication      _pixelwise_mul_forget_gate;
+    CLActivationLayer              _activation_forget_gate;
+    CLFullyConnectedLayer          _fully_connected_cell_state;
+    CLGEMM                         _gemm_cell_state1;
+    CLTransposeKernel              _transpose_cell_state;
+    CLArithmeticAddition           _accum_cell_state1;
+    CLArithmeticAddition           _accum_cell_state2;
+    CLPixelWiseMultiplication      _pixelwise_mul_cell_state1;
+    CLActivationLayer              _activation_cell_state;
+    CLActivationLayer              _cell_clip;
+    CLPixelWiseMultiplication      _pixelwise_mul_cell_state2;
+    CLFullyConnectedLayer          _fully_connected_output;
+    CLPixelWiseMultiplication      _pixelwise_mul_output_state1;
+    CLArithmeticAddition           _accum_output1;
+    CLActivationLayer              _activation_output;
+    CLActivationLayer              _activation_output_state;
+    CLPixelWiseMultiplication      _pixelwise_mul_output_state2;
+    CLFullyConnectedLayer          _fully_connected_output_state;
+    CLActivationLayer              _projection_clip;
+    CLCopyKernel                   _copy_cell_state;
+    CLCopyKernel                   _copy_output;
+    CLConcatenateLayer             _concat_scratch_buffer;
+    CLConcatenateLayer             _concat_inputs_forget_gate;
+    CLConcatenateLayer             _concat_weights_forget_gate;
+    CLConcatenateLayer             _concat_weights_input_gate;
+    CLConcatenateLayer             _concat_weights_output;
+    CLMemsetKernel                 _ones_memset_kernel;
+    CLMeanStdDevNormalizationLayer _mean_std_norm_input_gate;
+    CLPixelWiseMultiplication      _pixelwise_mul_input_gate_coeff;
+    CLArithmeticAddition           _accum_input_gate_bias;
+    CLMeanStdDevNormalizationLayer _mean_std_norm_forget_gate;
+    CLPixelWiseMultiplication      _pixelwise_mul_forget_gate_coeff;
+    CLArithmeticAddition           _accum_forget_gate_bias;
+    CLMeanStdDevNormalizationLayer _mean_std_norm_cell_gate;
+    CLPixelWiseMultiplication      _pixelwise_mul_cell_gate_coeff;
+    CLArithmeticAddition           _accum_cell_gate_bias;
+    CLMeanStdDevNormalizationLayer _mean_std_norm_output_gate;
+    CLPixelWiseMultiplication      _pixelwise_mul_output_gate_coeff;
+    CLArithmeticAddition           _accum_output_gate_bias;
+    CLTensor                       _input_gate_out1;
+    CLTensor                       _input_gate_out2;
+    CLTensor                       _input_gate_out3;
+    CLTensor                       _input_gate_out4;
+    CLTensor                       _forget_gate_out1;
+    CLTensor                       _forget_gate_out2;
+    CLTensor                       _forget_gate_out3;
+    CLTensor                       _forget_gate_out4;
+    CLTensor                       _forget_gate_out5;
+    CLTensor                       _forget_gate_out6;
+    CLTensor                       _cell_state_out1;
+    CLTensor                       _cell_state_out2;
+    CLTensor                       _cell_state_out3;
+    CLTensor                       _cell_state_out4;
+    CLTensor                       _cell_state_out5;
+    CLTensor                       _output1;
+    CLTensor                       _output2;
+    CLTensor                       _output3;
+    CLTensor                       _output4;
+    CLTensor                       _cell_state_activation;
+    CLTensor                       _output_state1;
+    CLTensor                       _ones;
+    CLTensor                       _input_layer_norm_out1;
+    CLTensor                       _input_layer_norm_out2;
+    CLTensor                       _forget_layer_norm_out1;
+    CLTensor                       _forget_layer_norm_out2;
+    CLTensor                       _cell_layer_norm_out1;
+    CLTensor                       _cell_layer_norm_out2;
+    CLTensor                       _output_layer_norm_out1;
+    CLTensor                       _output_layer_norm_out2;
+    bool                           _run_peephole_opt;
+    bool                           _run_cifg_opt;
+    bool                           _perform_cell_clipping;
+    bool                           _has_projection_weights;
+    bool                           _perform_projection_clipping;
+    bool                           _is_prepared;
+    bool                           _is_layer_norm_lstm;
 };
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_CLLSTMLAYER_H */
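
Swapping the CLPixelWiseMultiplicationKernel members for
CLPixelWiseMultiplication functions means CLLSTMLayer no longer schedules
these kernels by hand; the corresponding change in CLLSTMLayer::run()
(assumed, not shown in this excerpt) is of the form:

    // before: CLScheduler::get().enqueue(_pixelwise_mul_forget_gate);
    // after:
    _pixelwise_mul_forget_gate.run();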
diff --git a/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h b/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h
index b87daba..ca8d77e 100644
--- a/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h
+++ b/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h
@@ -24,18 +24,141 @@
 #ifndef ARM_COMPUTE_CLPIXELWISEMULTIPLICATION_H
 #define ARM_COMPUTE_CLPIXELWISEMULTIPLICATION_H
 
-#include "arm_compute/core/Types.h"
-#include "arm_compute/runtime/CL/ICLSimpleFunction.h"
+#include "arm_compute/core/CL/kernels/CLFillBorderKernel.h"
+#include "arm_compute/runtime/CL/ICLOperator.h"
+#include "arm_compute/runtime/IFunction.h"
 
 namespace arm_compute
 {
 // Forward declaration
 class ICLTensor;
 
+namespace experimental
+{
 /** Basic function to run @ref CLPixelWiseMultiplicationKernel. */
-class CLPixelWiseMultiplication : public ICLSimpleFunction
+class CLPixelWiseMultiplication : public ICLOperator
 {
 public:
+    /** Default Constructor */
+    CLPixelWiseMultiplication();
+    /** Initialise the kernel's inputs, output and conversion policy.
+     *
+     * Valid configurations (Input1,Input2) -> Output :
+     *
+     *   - (U8,U8)                         -> U8
+     *   - (U8,U8)                         -> S16
+     *   - (U8,S16)                        -> S16
+     *   - (S16,U8)                        -> S16
+     *   - (S16,S16)                       -> S16
+     *   - (F16,F16)                       -> F16
+     *   - (F32,F32)                       -> F32
+     *   - (QASYMM8,QASYMM8)               -> QASYMM8
+     *   - (QASYMM8_SIGNED,QASYMM8_SIGNED) -> QASYMM8_SIGNED
+     *   - (QSYMM16,QSYMM16)               -> QSYMM16
+     *   - (QSYMM16,QSYMM16)               -> S32
+     *
+     * @param[in]      compile_context The compile context to be used.
+     * @param[in, out] input1          An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+     *                                 The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0.
+     * @param[in, out] input2          An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+     *                                 The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0.
+     * @param[out]     output          The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+     * @param[in]      scale           Scale to apply after multiplication.
+     *                                 Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
+     * @param[in]      overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate
+     * @param[in]      rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even.
+     * @param[in]      act_info        (Optional) Activation layer information in case of a fused activation.
+     */
+    void configure(const CLCompileContext &compile_context, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, float scale,
+                   ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info = ActivationLayerInfo());
+    /** Static function to check if given info will lead to a valid configuration of @ref CLPixelWiseMultiplication
+     *
+     * Valid configurations (Input1,Input2) -> Output :
+     *
+     *   - (U8,U8)                         -> U8
+     *   - (U8,U8)                         -> S16
+     *   - (U8,S16)                        -> S16
+     *   - (S16,U8)                        -> S16
+     *   - (S16,S16)                       -> S16
+     *   - (F16,F16)                       -> F16
+     *   - (F32,F32)                       -> F32
+     *   - (QASYMM8,QASYMM8)               -> QASYMM8
+     *   - (QASYMM8_SIGNED,QASYMM8_SIGNED) -> QASYMM8_SIGNED
+     *   - (QSYMM16,QSYMM16)               -> QSYMM16
+     *   - (QSYMM16,QSYMM16)               -> S32
+     *
+     * @param[in] input1          An input tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+     * @param[in] input2          An input tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+     * @param[in] output          The output tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+     * @param[in] scale           Scale to apply after multiplication.
+     *                            Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
+     * @param[in] overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate
+     * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even.
+     * @param[in] act_info        (Optional) Activation layer information in case of a fused activation.
+     *
+     * @return a status
+     */
+    static Status validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale,
+                           ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info = ActivationLayerInfo());
+
+    // Inherited methods overridden:
+    void run(InputTensorMap inputs, OutputTensorMap outputs, OperatorTensorMap workspace) override;
+
+private:
+    CLFillBorderKernel _border_handler;
+};
+
+/** Basic function to run @ref CLComplexPixelWiseMultiplicationKernel. */
+class CLComplexPixelWiseMultiplication : public ICLOperator
+{
+public:
+    /** Default Constructor */
+    CLComplexPixelWiseMultiplication();
+    /** Initialise the kernel's inputs and output.
+     *
+     * @param[in]      compile_context The compile context to be used.
+     * @param[in, out] input1          An input tensor. Data types supported: F32. Number of channels supported: 2.
+     *                                 The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0.
+     * @param[in, out] input2          An input tensor. Data types supported: same as @p input1. Number of channels supported: same as @p input1.
+     *                                 The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0.
+     * @param[out]     output          The output tensor. Data types supported: same as @p input1. Number of channels supported: same as @p input1.
+     * @param[in]      act_info        (Optional) Activation layer information in case of a fused activation.
+     */
+    void configure(const CLCompileContext &compile_context, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo());
+    /** Static function to check if given info will lead to a valid configuration of @ref CLComplexPixelWiseMultiplication
+     *
+     * @param[in] input1   An input tensor info. Data types supported: F32. Number of channels supported: 2.
+     * @param[in] input2   An input tensor info. Data types supported: same as @p input1. Number of channels supported: same as @p input1.
+     * @param[in] output   The output tensor info. Data types supported: same as @p input1. Number of channels supported: same as @p input1.
+     * @param[in] act_info (Optional) Activation layer information in case of a fused activation.
+     *
+     * @return a status
+     */
+    static Status validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo());
+
+    // Inherited methods overridden:
+    void run(InputTensorMap inputs, OutputTensorMap outputs, OperatorTensorMap workspace) override;
+
+private:
+    CLFillBorderKernel _border_handler;
+};
+} // namespace experimental
+
+/** Basic function to run @ref CLPixelWiseMultiplicationKernel. */
+class CLPixelWiseMultiplication : public IFunction
+{
+public:
+    /** Default Constructor */
+    CLPixelWiseMultiplication();
+    /** Default Destructor */
+    ~CLPixelWiseMultiplication();
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    CLPixelWiseMultiplication(const CLPixelWiseMultiplication &) = delete;
+    /** Default move constructor */
+    CLPixelWiseMultiplication(CLPixelWiseMultiplication &&);
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    CLPixelWiseMultiplication &operator=(const CLPixelWiseMultiplication &) = delete;
+    /** Default move assignment operator */
+    CLPixelWiseMultiplication &operator=(CLPixelWiseMultiplication &&);
     /** Initialise the kernel's inputs, output and convertion policy.
      *
      * Valid configurations (Input1,Input2) -> Output :
@@ -125,12 +248,31 @@
      */
     static Status validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale,
                            ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info = ActivationLayerInfo());
+
+    // Inherited methods overridden:
+    void run() override;
+
+private:
+    struct Impl;
+    std::unique_ptr<Impl> _impl;
 };
 
 /** Basic function to run @ref CLComplexPixelWiseMultiplicationKernel. */
-class CLComplexPixelWiseMultiplication : public ICLSimpleFunction
+class CLComplexPixelWiseMultiplication : public IFunction
 {
 public:
+    /** Default Constructor */
+    CLComplexPixelWiseMultiplication();
+    /** Default Destructor */
+    ~CLComplexPixelWiseMultiplication();
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    CLComplexPixelWiseMultiplication(const CLComplexPixelWiseMultiplication &) = delete;
+    /** Default move constructor */
+    CLComplexPixelWiseMultiplication(CLComplexPixelWiseMultiplication &&);
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    CLComplexPixelWiseMultiplication &operator=(const CLComplexPixelWiseMultiplication &) = delete;
+    /** Default move assignment operator */
+    CLComplexPixelWiseMultiplication &operator=(CLComplexPixelWiseMultiplication &&);
     /** Initialise the kernel's inputs, output.
      *
      * @param[in, out] input1   An input tensor. Data types supported: F32. Number of channels supported: 2.
@@ -160,6 +302,13 @@
      * @param[in] act_info (Optional) Activation layer information in case of a fused activation.
      */
     static Status validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo());
+
+    // Inherited methods overridden:
+    void run() override;
+
+private:
+    struct Impl;
+    std::unique_ptr<Impl> _impl;
 };
 } // namespace arm_compute
 #endif /*ARM_COMPUTE_CLPIXELWISEMULTIPLICATION_H */
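
The header now exposes two layers: a stateless
experimental::CLPixelWiseMultiplication operator, and a backwards-compatible
IFunction wrapper that keeps its bound tensors behind a pimpl. A plausible
sketch of the wrapper side, assuming Impl holds the bound tensors plus the
operator (the real definition lives in the .cpp):

    struct CLPixelWiseMultiplication::Impl
    {
        const ICLTensor                                          *src_0{ nullptr };
        const ICLTensor                                          *src_1{ nullptr };
        ICLTensor                                                *dst{ nullptr };
        std::unique_ptr<experimental::CLPixelWiseMultiplication>  op{ nullptr };
    };

    void CLPixelWiseMultiplication::run()
    {
        const InputTensorMap  src{ { TensorType::ACL_SRC_0, _impl->src_0 }, { TensorType::ACL_SRC_1, _impl->src_1 } };
        const OutputTensorMap dst{ { TensorType::ACL_DST, _impl->dst } };
        _impl->op->run(src, dst, {});
    }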
diff --git a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
index 0aea91a..53f337b 100644
--- a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
+++ b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
@@ -26,13 +26,13 @@
 
 #include "arm_compute/core/CL/kernels/CLCopyKernel.h"
 #include "arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h"
-#include "arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h"
 #include "arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/runtime/CL/functions/CLActivationLayer.h"
 #include "arm_compute/runtime/CL/functions/CLElementwiseOperations.h"
 #include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h"
 #include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h"
+#include "arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h"
 #include "arm_compute/runtime/CL/functions/CLTranspose.h"
 
 #include "arm_compute/runtime/common/LSTMParams.h"
@@ -52,7 +52,7 @@
  * -# @ref CLGEMMLowpMatrixMultiplyCore                          Quantized matrix multiplication core. Accumulators are 32-bit integers
  * -# @ref CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint   Convert 32-bit integers into QSYMM16
  * -# @ref CLGEMMLowpMatrixAReductionKernel                      For precomputing effective biases to use
- * -# @ref CLPixelWiseMultiplicationKernel                       Elementwise multiplication
+ * -# @ref CLPixelWiseMultiplication                             Elementwise multiplication
  * -# @ref CLTranspose                                           Transpose function for reshaping the weights
  * */
 class CLQLSTMLayer : public IFunction
@@ -113,7 +113,7 @@
                    const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                    const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                    const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
-                   const ICLTensor *cell_state_in, const ICLTensor *output_state_in,
+                   ICLTensor *cell_state_in, const ICLTensor *output_state_in,
                    ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
                    const LSTMParams<ICLTensor> &lstm_params);
 
@@ -163,7 +163,7 @@
                    const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                    const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                    const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
-                   const ICLTensor *cell_state_in, const ICLTensor *output_state_in,
+                   ICLTensor *cell_state_in, const ICLTensor *output_state_in,
                    ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
                    const LSTMParams<ICLTensor> &lstm_params);
 
@@ -306,7 +306,7 @@
     CLArithmeticAddition             _projection_bias_add{};
     CLGEMMLowpMatrixMultiplyCore     _mm_input_to_forget{};
     CLGEMMLowpMatrixMultiplyCore     _mm_recurrent_to_forget{};
-    CLPixelWiseMultiplicationKernel  _pixelwise_mul_cell_to_forget{};
+    CLPixelWiseMultiplication        _pixelwise_mul_cell_to_forget{};
     CLGEMMLowpOutputStage            _input_to_forget_outstage{};
     CLGEMMLowpOutputStage            _recurrent_to_forget_outstage{};
     CLGEMMLowpOutputStage            _cell_to_forget_outstage{};
@@ -325,12 +325,12 @@
     CLGEMMLowpMatrixMultiplyCore     _mm_recurrent_to_input{};
     CLGEMMLowpOutputStage            _recurrent_to_input_outstage{};
     CLArithmeticAddition             _accumulate_input_recurrent_input{};
-    CLPixelWiseMultiplicationKernel  _pixelwise_mul_cell_to_input{};
+    CLPixelWiseMultiplication        _pixelwise_mul_cell_to_input{};
     CLGEMMLowpOutputStage            _cell_to_input_outstage{};
     CLArithmeticAddition             _accumulate_cell_input{};
     CLActivationLayer                _input_gate_sigmoid{};
-    CLPixelWiseMultiplicationKernel  _pixelwise_mul_forget_cell{};
-    CLPixelWiseMultiplicationKernel  _pixelwise_mul_input_cell{};
+    CLPixelWiseMultiplication        _pixelwise_mul_forget_cell{};
+    CLPixelWiseMultiplication        _pixelwise_mul_input_cell{};
     CLArithmeticAddition             _add_forget_cell{};
     CLActivationLayer                _cell_clip{};
     CLGEMMLowpMatrixMultiplyCore     _mm_input_to_output{};
@@ -338,12 +338,12 @@
     CLGEMMLowpMatrixMultiplyCore     _mm_recurrent_to_output{};
     CLGEMMLowpOutputStage            _recurrent_to_output_outstage{};
     CLArithmeticAddition             _accumulate_input_recurrent_output{};
-    CLPixelWiseMultiplicationKernel  _pixelwise_mul_cell_to_output{};
+    CLPixelWiseMultiplication        _pixelwise_mul_cell_to_output{};
     CLGEMMLowpOutputStage            _cell_to_output_outstage{};
     CLArithmeticAddition             _accumulate_cell_to_output{};
     CLActivationLayer                _output_gate_sigmoid{};
     CLActivationLayer                _hidden_tanh{};
-    CLPixelWiseMultiplicationKernel  _pixelwise_mul_hidden{};
+    CLPixelWiseMultiplication        _pixelwise_mul_hidden{};
     CLGEMMLowpOutputStage            _hidden_outstage{};
     CLGEMMLowpMatrixMultiplyCore     _mm_projection{};
     CLGEMMLowpOutputStage            _projection_outstage{};
diff --git a/arm_compute/runtime/common/LSTMParams.h b/arm_compute/runtime/common/LSTMParams.h
index 82fca7e..ffb4ddd 100644
--- a/arm_compute/runtime/common/LSTMParams.h
+++ b/arm_compute/runtime/common/LSTMParams.h
@@ -81,7 +81,7 @@
      *
      * @return Reference to this LSTMParams object
      */
-    LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, const T *cell_to_input_weights, const T *input_gate_bias)
+    LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, T *cell_to_input_weights, const T *input_gate_bias)
     {
         _input_to_input_weights     = input_to_input_weights;
         _recurrent_to_input_weights = recurrent_to_input_weights;
@@ -111,7 +111,7 @@
      *
      * @return Reference to this LSTMParams object
      */
-    LSTMParams &set_peephole_params(const T *cell_to_forget_weights, const T *cell_to_output_weights)
+    LSTMParams &set_peephole_params(T *cell_to_forget_weights, T *cell_to_output_weights)
     {
         _cell_to_forget_weights = cell_to_forget_weights;
         _cell_to_output_weights = cell_to_output_weights;
@@ -127,8 +127,8 @@
      *
      * @return Reference to this LSTMParams object
      */
-    LSTMParams &set_layer_normalization_params(const T *input_layer_norm_weights, const T *forget_layer_norm_weights,
-                                               const T *cell_layer_norm_weights, const T *output_layer_norm_weights)
+    LSTMParams &set_layer_normalization_params(T *input_layer_norm_weights, T *forget_layer_norm_weights,
+                                               T *cell_layer_norm_weights, T *output_layer_norm_weights)
     {
         _input_layer_norm_weights  = input_layer_norm_weights;
         _forget_layer_norm_weights = forget_layer_norm_weights;
@@ -204,7 +204,7 @@
         return _recurrent_to_input_weights;
     }
 
-    const T *cell_to_input_weights() const
+    T *cell_to_input_weights() const
     {
         return _cell_to_input_weights;
     }
@@ -214,12 +214,12 @@
         return _input_gate_bias;
     }
 
-    const T *cell_to_forget_weights() const
+    T *cell_to_forget_weights() const
     {
         return _cell_to_forget_weights;
     }
 
-    const T *cell_to_output_weights() const
+    T *cell_to_output_weights() const
     {
         return _cell_to_output_weights;
     }
@@ -234,22 +234,22 @@
         return _projection_bias;
     }
 
-    const T *input_layer_norm_weights() const
+    T *input_layer_norm_weights() const
     {
         return _input_layer_norm_weights;
     }
 
-    const T *forget_layer_norm_weights() const
+    T *forget_layer_norm_weights() const
     {
         return _forget_layer_norm_weights;
     }
 
-    const T *cell_layer_norm_weights() const
+    T *cell_layer_norm_weights() const
     {
         return _cell_layer_norm_weights;
     }
 
-    const T *output_layer_norm_weights() const
+    T *output_layer_norm_weights() const
     {
         return _output_layer_norm_weights;
     }
@@ -317,16 +317,16 @@
 private:
     const T *_input_to_input_weights;
     const T *_recurrent_to_input_weights;
-    const T *_cell_to_input_weights;
+    T       *_cell_to_input_weights;
     const T *_input_gate_bias;
-    const T *_cell_to_forget_weights;
-    const T *_cell_to_output_weights;
+    T       *_cell_to_forget_weights;
+    T       *_cell_to_output_weights;
     const T *_projection_weights;
     const T *_projection_bias;
-    const T *_input_layer_norm_weights;
-    const T *_forget_layer_norm_weights;
-    const T *_cell_layer_norm_weights;
-    const T *_output_layer_norm_weights;
+    T       *_input_layer_norm_weights;
+    T       *_forget_layer_norm_weights;
+    T       *_cell_layer_norm_weights;
+    T       *_output_layer_norm_weights;
     float    _cell_clip;
     float    _projection_clip;
     float    _input_intermediate_scale;
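
The const qualifiers come off the peephole and layer-normalization accessors
because those tensors now feed CLPixelWiseMultiplication as mutable inputs:
the function takes [in, out] tensors whose TensorInfo may be updated when
dimension 0 is broadcast. An illustrative call site (the member name
_mul_cell_to_forget_res is hypothetical):

    // cell_to_forget_weights() must now return T* rather than const T*:
    _pixelwise_mul_cell_to_forget.configure(compile_context, cell_state_in,
                                            lstm_params.cell_to_forget_weights(),
                                            &_mul_cell_to_forget_res, 1.f,
                                            ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);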
diff --git a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
index 2289316..95869f7 100644
--- a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
+++ b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
@@ -29,6 +29,7 @@
 #include "arm_compute/core/CL/ICLTensor.h"
 #include "arm_compute/core/CL/OpenCL.h"
 #include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/utils/misc/Cast.h"
 #include "support/StringSupport.h"
 
 namespace arm_compute
@@ -142,21 +143,21 @@
 {
 }
 
-void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float scale,
+void CLPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, float scale,
                                                 ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info)
 {
     configure(CLKernelLibrary::get().get_compile_context(), input1, input2, output, scale, overflow_policy, rounding_policy, act_info);
 }
 
-void CLPixelWiseMultiplicationKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float scale,
+void CLPixelWiseMultiplicationKernel::configure(const CLCompileContext &compile_context, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, float scale,
                                                 ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info(),
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1, input2, output,
                                                   scale, overflow_policy, rounding_policy, act_info));
 
     // Configure kernel window
-    auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
+    auto win_config = validate_and_configure_window(input1, input2, output);
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
 
     _input1 = input1;
@@ -179,14 +180,14 @@
 
     std::string acc_type;
     // Check if it has float inputs and output
-    if(is_data_type_float(input1->info()->data_type()) || is_data_type_float(input2->info()->data_type()))
+    if(is_data_type_float(input1->data_type()) || is_data_type_float(input2->data_type()))
     {
         scale_int = -1;
-        acc_type  = (input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32) ? "float" : "half";
+        acc_type  = (input1->data_type() == DataType::F32 || input2->data_type() == DataType::F32) ? "float" : "half";
     }
     else
     {
-        if(input1->info()->element_size() == 2 || input2->info()->element_size() == 2)
+        if(input1->element_size() == 2 || input2->element_size() == 2)
         {
             // Use 32-bit accumulator for 16-bit input
             acc_type = "int";
@@ -198,26 +199,26 @@
         }
     }
 
-    const bool is_quantized = is_data_type_quantized(input1->info()->data_type());
+    const bool is_quantized = is_data_type_quantized(input1->data_type());
 
     // Set kernel build options
     std::string    kernel_name = "pixelwise_mul";
     CLBuildOptions build_opts;
-    build_opts.add_option("-DDATA_TYPE_IN1=" + get_cl_type_from_data_type(input1->info()->data_type()));
-    build_opts.add_option("-DDATA_TYPE_IN2=" + get_cl_type_from_data_type(input2->info()->data_type()));
-    build_opts.add_option("-DDATA_TYPE_OUT=" + get_cl_type_from_data_type(output->info()->data_type()));
+    build_opts.add_option("-DDATA_TYPE_IN1=" + get_cl_type_from_data_type(input1->data_type()));
+    build_opts.add_option("-DDATA_TYPE_IN2=" + get_cl_type_from_data_type(input2->data_type()));
+    build_opts.add_option("-DDATA_TYPE_OUT=" + get_cl_type_from_data_type(output->data_type()));
     build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration));
-    if(is_quantized && (output->info()->data_type() != DataType::S32))
+    if(is_quantized && (output->data_type() != DataType::S32))
     {
-        const UniformQuantizationInfo iq1_info = input1->info()->quantization_info().uniform();
-        const UniformQuantizationInfo iq2_info = input2->info()->quantization_info().uniform();
-        const UniformQuantizationInfo oq_info  = output->info()->quantization_info().uniform();
+        const UniformQuantizationInfo iq1_info = input1->quantization_info().uniform();
+        const UniformQuantizationInfo iq2_info = input2->quantization_info().uniform();
+        const UniformQuantizationInfo oq_info  = output->quantization_info().uniform();
 
-        build_opts.add_option_if(is_data_type_quantized_asymmetric(input1->info()->data_type()),
+        build_opts.add_option_if(is_data_type_quantized_asymmetric(input1->data_type()),
                                  "-DOFFSET_IN1=" + support::cpp11::to_string(iq1_info.offset));
-        build_opts.add_option_if(is_data_type_quantized_asymmetric(input2->info()->data_type()),
+        build_opts.add_option_if(is_data_type_quantized_asymmetric(input2->data_type()),
                                  "-DOFFSET_IN2=" + support::cpp11::to_string(iq2_info.offset));
-        build_opts.add_option_if(is_data_type_quantized_asymmetric(output->info()->data_type()),
+        build_opts.add_option_if(is_data_type_quantized_asymmetric(output->data_type()),
                                  "-DOFFSET_OUT=" + support::cpp11::to_string(oq_info.offset));
         build_opts.add_option("-DSCALE_IN1=" + float_to_string_with_full_precision(iq1_info.scale));
         build_opts.add_option("-DSCALE_IN2=" + float_to_string_with_full_precision(iq2_info.scale));
@@ -227,7 +228,7 @@
     else
     {
         kernel_name += (scale_int >= 0) ? "_int" : "_float";
-        build_opts.add_option_if_else(overflow_policy == ConvertPolicy::WRAP || is_data_type_float(output->info()->data_type()), "-DWRAP", "-DSATURATE");
+        build_opts.add_option_if_else(overflow_policy == ConvertPolicy::WRAP || is_data_type_float(output->data_type()), "-DWRAP", "-DSATURATE");
         build_opts.add_option_if_else(rounding_policy == RoundingPolicy::TO_ZERO, "-DROUND=_rtz", "-DROUND=_rte");
         build_opts.add_option("-DACC_DATA_TYPE=" + acc_type);
         if(act_info.enabled())
@@ -266,14 +267,18 @@
     return Status{};
 }
 
-void CLPixelWiseMultiplicationKernel::run(const Window &window, cl::CommandQueue &queue)
+void CLPixelWiseMultiplicationKernel::run_op(const InputTensorMap &inputs, const OutputTensorMap &outputs, const Window &window, cl::CommandQueue &queue)
 {
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
 
-    const TensorShape &in_shape1 = _input1->info()->tensor_shape();
-    const TensorShape &in_shape2 = _input2->info()->tensor_shape();
-    const TensorShape &out_shape = _output->info()->tensor_shape();
+    const auto src_0 = utils::cast::polymorphic_downcast<const ICLTensor *>(inputs.at(TensorType::ACL_SRC_0));
+    const auto src_1 = utils::cast::polymorphic_downcast<const ICLTensor *>(inputs.at(TensorType::ACL_SRC_1));
+    auto       dst   = utils::cast::polymorphic_downcast<ICLTensor *>(outputs.at(TensorType::ACL_DST));
+
+    const TensorShape &in_shape1 = src_0->info()->tensor_shape();
+    const TensorShape &in_shape2 = src_1->info()->tensor_shape();
+    const TensorShape &out_shape = dst->info()->tensor_shape();
 
     bool can_collapse = true;
     if(std::min(in_shape1.total_size(), in_shape2.total_size()) > 1)
@@ -298,9 +303,9 @@
     do
     {
         unsigned int idx = 0;
-        add_3D_tensor_argument(idx, _input1, slice_input1);
-        add_3D_tensor_argument(idx, _input2, slice_input2);
-        add_3D_tensor_argument(idx, _output, slice);
+        add_3D_tensor_argument(idx, src_0, slice_input1);
+        add_3D_tensor_argument(idx, src_1, slice_input2);
+        add_3D_tensor_argument(idx, dst, slice);
         enqueue(queue, *this, slice, lws_hint());
 
         ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input1));
@@ -311,7 +316,7 @@
 
 BorderSize CLPixelWiseMultiplicationKernel::border_size() const
 {
-    const unsigned int replicateSize = _output->info()->dimension(0) - std::min(_input1->info()->dimension(0), _input2->info()->dimension(0));
+    const unsigned int replicateSize = _output->dimension(0) - std::min(_input1->dimension(0), _input2->dimension(0));
     const unsigned int border        = std::min<unsigned int>(num_elems_processed_per_iteration - 1U, replicateSize);
     return BorderSize{ 0, border, 0, 0 };
 }
@@ -374,18 +379,18 @@
 {
 }
 
-void CLComplexPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, const ActivationLayerInfo &act_info)
+void CLComplexPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, const ActivationLayerInfo &act_info)
 {
     configure(CLKernelLibrary::get().get_compile_context(), input1, input2, output, act_info);
 }
 
-void CLComplexPixelWiseMultiplicationKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, const ActivationLayerInfo &act_info)
+void CLComplexPixelWiseMultiplicationKernel::configure(const CLCompileContext &compile_context, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, const ActivationLayerInfo &act_info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_complex(input1->info(), input2->info(), output->info(), act_info));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_complex(input1, input2, output, act_info));
 
     // Configure kernel window
-    auto win_config = validate_and_configure_window_complex(input1->info(), input2->info(), output->info());
+    auto win_config = validate_and_configure_window_complex(input1, input2, output);
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
 
     _input1 = input1;
@@ -415,14 +420,18 @@
     return Status{};
 }
 
-void CLComplexPixelWiseMultiplicationKernel::run(const Window &window, cl::CommandQueue &queue)
+void CLComplexPixelWiseMultiplicationKernel::run_op(const InputTensorMap &inputs, const OutputTensorMap &outputs, const Window &window, cl::CommandQueue &queue)
 {
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
 
-    const TensorShape &in_shape1 = _input1->info()->tensor_shape();
-    const TensorShape &in_shape2 = _input2->info()->tensor_shape();
-    const TensorShape &out_shape = _output->info()->tensor_shape();
+    const auto src_0 = utils::cast::polymorphic_downcast<const ICLTensor *>(inputs.at(TensorType::ACL_SRC_0));
+    const auto src_1 = utils::cast::polymorphic_downcast<const ICLTensor *>(inputs.at(TensorType::ACL_SRC_1));
+    auto       dst   = utils::cast::polymorphic_downcast<ICLTensor *>(outputs.at(TensorType::ACL_DST));
+
+    const TensorShape &in_shape1 = src_0->info()->tensor_shape();
+    const TensorShape &in_shape2 = src_1->info()->tensor_shape();
+    const TensorShape &out_shape = dst->info()->tensor_shape();
 
     bool can_collapse = true;
     if(std::min(in_shape1.total_size(), in_shape2.total_size()) > 1)
@@ -447,9 +456,9 @@
     do
     {
         unsigned int idx = 0;
-        add_3D_tensor_argument(idx, _input1, slice_input1);
-        add_3D_tensor_argument(idx, _input2, slice_input2);
-        add_3D_tensor_argument(idx, _output, slice);
+        add_3D_tensor_argument(idx, src_0, slice_input1);
+        add_3D_tensor_argument(idx, src_1, slice_input2);
+        add_3D_tensor_argument(idx, dst, slice);
         enqueue(queue, *this, slice, lws_hint());
 
         ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input1));
@@ -460,7 +469,7 @@
 
 BorderSize CLComplexPixelWiseMultiplicationKernel::border_size() const
 {
-    const unsigned int replicateSize = _output->info()->dimension(0) - std::min(_input1->info()->dimension(0), _input2->info()->dimension(0));
+    const unsigned int replicateSize = _output->dimension(0) - std::min(_input1->dimension(0), _input2->dimension(0));
     const unsigned int border        = std::min<unsigned int>(num_elems_processed_per_iteration_complex - 1U, replicateSize);
     return BorderSize{ 0, border, 0, 0 };
 }
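Note: with configure() now taking ITensorInfo, the kernel no longer caches ICLTensor pointers; the actual tensors are bound at execution time through run_op(). A minimal sketch of the new calling convention (tensor and queue names are illustrative, not part of this patch):

    InputTensorMap  inputs{ { TensorType::ACL_SRC_0, src0 }, { TensorType::ACL_SRC_1, src1 } };
    OutputTensorMap outputs{ { TensorType::ACL_DST, dst } };
    kernel.run_op(inputs, outputs, kernel.window(), queue);

In practice this binding is driven by CLScheduler::get().enqueue_op(kernel, inputs, outputs), as the border-handler call later in this patch shows.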
diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp
index a1c4124..058b602 100644
--- a/src/runtime/CL/functions/CLLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLLSTMLayer.cpp
@@ -55,7 +55,7 @@
                             const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                             const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                             const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
-                            const ICLTensor *output_state_in, const ICLTensor *cell_state_in,
+                            const ICLTensor *output_state_in, ICLTensor *cell_state_in,
                             ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
                             const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
 {
@@ -68,7 +68,7 @@
                             const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                             const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                             const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
-                            const ICLTensor *output_state_in, const ICLTensor *cell_state_in,
+                            const ICLTensor *output_state_in, ICLTensor *cell_state_in,
                             ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
                             const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
 {
@@ -489,14 +489,14 @@
 
     if(lstm_params.has_peephole_opt())
     {
-        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
     }
     if(lstm_params.use_layer_norm())
     {
         ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&forget_gate));
-        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE,
-                                                                              RoundingPolicy::TO_NEAREST_EVEN));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE,
+                                                                        RoundingPolicy::TO_NEAREST_EVEN));
         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE));
     }
     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
@@ -524,14 +524,14 @@
         {
             ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
             ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
-            ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+            ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
             ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
         }
 
         if(lstm_params.use_layer_norm())
         {
             ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&input_gate));
-            ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+            ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
             ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(), &input_gate, ConvertPolicy::SATURATE));
         }
         ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
@@ -548,13 +548,13 @@
     if(lstm_params.use_layer_norm())
     {
         ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&cell_state_tmp));
-        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE,
-                                                                              RoundingPolicy::TO_NEAREST_EVEN));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE,
+                                                                        RoundingPolicy::TO_NEAREST_EVEN));
         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE));
     }
     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, activation_info));
-    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
-    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
     ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
     if(cell_threshold != 0.f)
     {
@@ -573,22 +573,22 @@
 
     if(lstm_params.has_peephole_opt())
     {
-        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
-                                                                              RoundingPolicy::TO_NEAREST_EVEN));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
+                                                                        RoundingPolicy::TO_NEAREST_EVEN));
         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp, ConvertPolicy::SATURATE));
     }
     if(lstm_params.use_layer_norm())
     {
         ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&output_gate_tmp));
-        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
-                                                                              RoundingPolicy::TO_NEAREST_EVEN));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
+                                                                        RoundingPolicy::TO_NEAREST_EVEN));
         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp, ConvertPolicy::SATURATE));
     }
     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
 
     // Validate output state
     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
-    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
     if(lstm_params.has_projection())
     {
         ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out));
@@ -629,13 +629,13 @@
 
     if(_run_peephole_opt)
     {
-        CLScheduler::get().enqueue(_pixelwise_mul_forget_gate);
+        _pixelwise_mul_forget_gate.run();
         _accum_forget_gate1.run();
     }
     if(_is_layer_norm_lstm)
     {
         _mean_std_norm_forget_gate.run();
-        CLScheduler::get().enqueue(_pixelwise_mul_forget_gate_coeff);
+        _pixelwise_mul_forget_gate_coeff.run();
         _accum_forget_gate_bias.run();
     }
     _activation_forget_gate.run();
@@ -651,14 +651,14 @@
 
         if(_run_peephole_opt)
         {
-            CLScheduler::get().enqueue(_pixelwise_mul_input_gate);
+            _pixelwise_mul_input_gate.run();
             _accum_input_gate1.run();
         }
 
         if(_is_layer_norm_lstm)
         {
             _mean_std_norm_input_gate.run();
-            CLScheduler::get().enqueue(_pixelwise_mul_input_gate_coeff);
+            _pixelwise_mul_input_gate_coeff.run();
             _accum_input_gate_bias.run();
         }
         _activation_input_gate.run();
@@ -671,12 +671,12 @@
     if(_is_layer_norm_lstm)
     {
         _mean_std_norm_cell_gate.run();
-        CLScheduler::get().enqueue(_pixelwise_mul_cell_gate_coeff);
+        _pixelwise_mul_cell_gate_coeff.run();
         _accum_cell_gate_bias.run();
     }
     _activation_cell_state.run();
-    CLScheduler::get().enqueue(_pixelwise_mul_cell_state1);
-    CLScheduler::get().enqueue(_pixelwise_mul_cell_state2);
+    _pixelwise_mul_cell_state1.run();
+    _pixelwise_mul_cell_state2.run();
     _accum_cell_state2.run();
 
     if(_perform_cell_clipping)
@@ -688,19 +688,19 @@
 
     if(_run_peephole_opt)
     {
-        CLScheduler::get().enqueue(_pixelwise_mul_output_state1);
+        _pixelwise_mul_output_state1.run();
         _accum_output1.run();
     }
     if(_is_layer_norm_lstm)
     {
         _mean_std_norm_output_gate.run();
-        CLScheduler::get().enqueue(_pixelwise_mul_output_gate_coeff);
+        _pixelwise_mul_output_gate_coeff.run();
         _accum_output_gate_bias.run();
     }
     _activation_output.run();
 
     _activation_output_state.run();
-    CLScheduler::get().enqueue(_pixelwise_mul_output_state2);
+    _pixelwise_mul_output_state2.run();
 
     if(_has_projection_weights)
     {
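Note: throughout CLLSTMLayer::run(), direct kernel enqueues are replaced by calls into the function-level objects, so border handling and run-time tensor binding happen inside CLPixelWiseMultiplication::run() rather than at the call site. The member types change accordingly (a sketch of the pattern, assuming the matching header change):

    // Before: CLPixelWiseMultiplicationKernel _pixelwise_mul_forget_gate;  // run via CLScheduler::get().enqueue(...)
    // After:  CLPixelWiseMultiplication       _pixelwise_mul_forget_gate;  // run via _pixelwise_mul_forget_gate.run()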
diff --git a/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp b/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp
index 139209c..34e06a3 100644
--- a/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp
+++ b/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp
@@ -25,30 +25,50 @@
 
 #include "arm_compute/core/CL/ICLTensor.h"
 #include "arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h"
+#include "arm_compute/runtime/CL/CLScheduler.h"
 #include "support/MemorySupport.h"
 
 #include <utility>
 
 namespace arm_compute
 {
-void CLPixelWiseMultiplication::configure(ICLTensor *input1, ICLTensor *input2, ICLTensor *output, float scale,
-                                          ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info)
+namespace
 {
-    configure(CLKernelLibrary::get().get_compile_context(), input1, input2, output, scale, overflow_policy, rounding_policy, act_info);
+void select_border_input(InputTensorMap &tensor_map, const InputTensorMap &inputs, const OutputTensorMap &outputs)
+{
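+    // Pick whichever input is broadcast along X (width 1): that is the tensor whose border must be replicated.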
+    if(outputs.at(TensorType::ACL_DST)->info()->dimension(0) > 1)
+    {
+        if(inputs.at(TensorType::ACL_SRC_1)->info()->dimension(0) == 1)
+        {
+            tensor_map[TensorType::ACL_SRC] = inputs.at(TensorType::ACL_SRC_1);
+        }
+        else
+        {
+            tensor_map[TensorType::ACL_SRC] = inputs.at(TensorType::ACL_SRC_0);
+        }
+    }
+}
+} // namespace
+
+namespace experimental
+{
+CLPixelWiseMultiplication::CLPixelWiseMultiplication()
+    : _border_handler()
+{
 }
 
-void CLPixelWiseMultiplication::configure(const CLCompileContext &compile_context, ICLTensor *input1, ICLTensor *input2, ICLTensor *output, float scale,
+void CLPixelWiseMultiplication::configure(const CLCompileContext &compile_context, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, float scale,
                                           ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info)
 {
     auto k = arm_compute::support::cpp14::make_unique<CLPixelWiseMultiplicationKernel>();
     k->configure(compile_context, input1, input2, output, scale, overflow_policy, rounding_policy, act_info);
     _kernel = std::move(k);
 
-    if(output->info()->dimension(0) > 1)
+    if(output->dimension(0) > 1)
     {
-        ICLTensor *broadcasted_info = (input1->info()->dimension(0) == 1) ? input1 : input2;
+        ITensorInfo *broadcasted_info = (input1->dimension(0) == 1) ? input1 : input2;
 
-        if(broadcasted_info->info()->dimension(0) == 1)
+        if(broadcasted_info->dimension(0) == 1)
         {
             _border_handler.configure(compile_context, broadcasted_info, _kernel->border_size(), BorderMode::REPLICATE);
         }
@@ -61,22 +81,30 @@
     return CLPixelWiseMultiplicationKernel::validate(input1, input2, output, scale, overflow_policy, rounding_policy, act_info);
 }
 
-void CLComplexPixelWiseMultiplication::configure(ICLTensor *input1, ICLTensor *input2, ICLTensor *output, const ActivationLayerInfo &act_info)
+void CLPixelWiseMultiplication::run(InputTensorMap inputs, OutputTensorMap outputs, OperatorTensorMap workspace)
 {
-    configure(CLKernelLibrary::get().get_compile_context(), input1, input2, output, act_info);
+    InputTensorMap src;
+    select_border_input(src, inputs, outputs);
+    CLScheduler::get().enqueue_op(_border_handler, src, {});
+    ICLOperator::run(inputs, outputs, workspace);
 }
 
-void CLComplexPixelWiseMultiplication::configure(const CLCompileContext &compile_context, ICLTensor *input1, ICLTensor *input2, ICLTensor *output, const ActivationLayerInfo &act_info)
+CLComplexPixelWiseMultiplication::CLComplexPixelWiseMultiplication()
+    : _border_handler()
+{
+}
+
+void CLComplexPixelWiseMultiplication::configure(const CLCompileContext &compile_context, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, const ActivationLayerInfo &act_info)
 {
     auto k = arm_compute::support::cpp14::make_unique<CLComplexPixelWiseMultiplicationKernel>();
     k->configure(compile_context, input1, input2, output, act_info);
     _kernel = std::move(k);
 
-    if(output->info()->dimension(0) > 1)
+    if(output->dimension(0) > 1)
     {
-        ICLTensor *broadcasted_info = (input1->info()->dimension(0) == 1) ? input1 : input2;
+        ITensorInfo *broadcasted_info = (input1->dimension(0) == 1) ? input1 : input2;
 
-        if(broadcasted_info->info()->dimension(0) == 1)
+        if(broadcasted_info->dimension(0) == 1)
         {
             _border_handler.configure(compile_context, broadcasted_info, _kernel->border_size(), BorderMode::REPLICATE);
         }
@@ -87,4 +115,102 @@
 {
     return CLComplexPixelWiseMultiplicationKernel::validate(input1, input2, output, act_info);
 }
+
+void CLComplexPixelWiseMultiplication::run(InputTensorMap inputs, OutputTensorMap outputs, OperatorTensorMap workspace)
+{
+    InputTensorMap src;
+    select_border_input(src, inputs, outputs);
+    CLScheduler::get().enqueue_op(_border_handler, src, {});
+    ICLOperator::run(inputs, outputs, workspace);
+}
+} // namespace experimental
+
+struct CLPixelWiseMultiplication::Impl
+{
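+    // Non-owning tensor pointers captured in configure(); run() rebuilds the tensor maps from them.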
+    const ICLTensor                                         *src_0{ nullptr };
+    const ICLTensor                                         *src_1{ nullptr };
+    ICLTensor                                               *dst{ nullptr };
+    std::unique_ptr<experimental::CLPixelWiseMultiplication> op{ nullptr };
+};
+
+CLPixelWiseMultiplication::CLPixelWiseMultiplication()
+    : _impl(support::cpp14::make_unique<Impl>())
+{
+}
+CLPixelWiseMultiplication::CLPixelWiseMultiplication(CLPixelWiseMultiplication &&) = default;
+CLPixelWiseMultiplication &CLPixelWiseMultiplication::operator=(CLPixelWiseMultiplication &&) = default;
+CLPixelWiseMultiplication::~CLPixelWiseMultiplication()                                       = default;
+
+void CLPixelWiseMultiplication::configure(ICLTensor *input1, ICLTensor *input2, ICLTensor *output, float scale,
+                                          ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info)
+{
+    configure(CLKernelLibrary::get().get_compile_context(), input1, input2, output, scale, overflow_policy, rounding_policy, act_info);
+}
+
+void CLPixelWiseMultiplication::configure(const CLCompileContext &compile_context, ICLTensor *input1, ICLTensor *input2, ICLTensor *output, float scale,
+                                          ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info)
+{
+    _impl->src_0 = input1;
+    _impl->src_1 = input2;
+    _impl->dst   = output;
+    _impl->op    = arm_compute::support::cpp14::make_unique<experimental::CLPixelWiseMultiplication>();
+    _impl->op->configure(compile_context, input1->info(), input2->info(), output->info(), scale, overflow_policy, rounding_policy, act_info);
+}
+
+Status CLPixelWiseMultiplication::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale,
+                                           ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info)
+{
+    return experimental::CLPixelWiseMultiplication::validate(input1, input2, output, scale, overflow_policy, rounding_policy, act_info);
+}
+
+void CLPixelWiseMultiplication::run()
+{
+    const InputTensorMap  src{ { TensorType::ACL_SRC_0, _impl->src_0 }, { TensorType::ACL_SRC_1, _impl->src_1 } };
+    const OutputTensorMap dst{ { TensorType::ACL_DST, _impl->dst } };
+
+    _impl->op->run(src, dst, {});
+}
+
+struct CLComplexPixelWiseMultiplication::Impl
+{
+    const ICLTensor                                                *src_0{ nullptr };
+    const ICLTensor                                                *src_1{ nullptr };
+    ICLTensor                                                      *dst{ nullptr };
+    std::unique_ptr<experimental::CLComplexPixelWiseMultiplication> op{ nullptr };
+};
+
+CLComplexPixelWiseMultiplication::CLComplexPixelWiseMultiplication()
+    : _impl(support::cpp14::make_unique<Impl>())
+{
+}
+CLComplexPixelWiseMultiplication::CLComplexPixelWiseMultiplication(CLComplexPixelWiseMultiplication &&) = default;
+CLComplexPixelWiseMultiplication &CLComplexPixelWiseMultiplication::operator=(CLComplexPixelWiseMultiplication &&) = default;
+CLComplexPixelWiseMultiplication::~CLComplexPixelWiseMultiplication()                                              = default;
+
+void CLComplexPixelWiseMultiplication::configure(ICLTensor *input1, ICLTensor *input2, ICLTensor *output, const ActivationLayerInfo &act_info)
+{
+    configure(CLKernelLibrary::get().get_compile_context(), input1, input2, output, act_info);
+}
+
+void CLComplexPixelWiseMultiplication::configure(const CLCompileContext &compile_context, ICLTensor *input1, ICLTensor *input2, ICLTensor *output, const ActivationLayerInfo &act_info)
+{
+    _impl->src_0 = input1;
+    _impl->src_1 = input2;
+    _impl->dst   = output;
+    _impl->op    = arm_compute::support::cpp14::make_unique<experimental::CLComplexPixelWiseMultiplication>();
+    _impl->op->configure(compile_context, input1->info(), input2->info(), output->info(), act_info);
+}
+
+Status CLComplexPixelWiseMultiplication::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, const ActivationLayerInfo &act_info)
+{
+    return experimental::CLComplexPixelWiseMultiplication::validate(input1, input2, output, act_info);
+}
+
+void CLComplexPixelWiseMultiplication::run()
+{
+    const InputTensorMap  src{ { TensorType::ACL_SRC_0, _impl->src_0 }, { TensorType::ACL_SRC_1, _impl->src_1 } };
+    const OutputTensorMap dst{ { TensorType::ACL_DST, _impl->dst } };
+
+    _impl->op->run(src, dst, {});
+}
 } // namespace arm_compute
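Note: this file now has two layers. experimental::CLPixelWiseMultiplication is a stateless operator configured from ITensorInfo and run on caller-supplied tensor maps; the original ICLTensor-based class becomes a thin facade that stores the tensors in a pimpl and forwards to the operator, so existing user code is unaffected. A minimal usage sketch (tensor names are illustrative):

    CLTensor a, b, out;
    // ... init/allocate the tensors ...
    CLPixelWiseMultiplication mul;
    mul.configure(&a, &b, &out, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
    mul.run(); // builds the ACL_SRC_0/ACL_SRC_1/ACL_DST maps and runs the experimental operator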
diff --git a/src/runtime/CL/functions/CLQLSTMLayer.cpp b/src/runtime/CL/functions/CLQLSTMLayer.cpp
index c5c4aa3..a40a5d0 100644
--- a/src/runtime/CL/functions/CLQLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLQLSTMLayer.cpp
@@ -113,7 +113,7 @@
                              const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                              const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                              const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
-                             const ICLTensor *cell_state_in, const ICLTensor *output_state_in,
+                             ICLTensor *cell_state_in, const ICLTensor *output_state_in,
                              ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
                              const LSTMParams<ICLTensor> &lstm_params)
 {
@@ -126,7 +126,7 @@
                              const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                              const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                              const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
-                             const ICLTensor *cell_state_in, const ICLTensor *output_state_in,
+                             ICLTensor *cell_state_in, const ICLTensor *output_state_in,
                              ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
                              const LSTMParams<ICLTensor> &lstm_params)
 {
@@ -382,7 +382,7 @@
         input_activation_input->allocator()->allocate();
     }
     // Cell.
-    // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel
+    // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplication
     _pixelwise_mul_forget_cell.configure(compile_context, &_forget_gate, cell_state_in, &_forget_gate, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
     const float      cell_gate_scale      = _cell_gate.info()->quantization_info().uniform().scale;
     const float      mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
@@ -418,7 +418,7 @@
 
     if(_has_peephole)
     {
-        // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel
+        // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplication
         // Here we are not using the output stage because all operations are done in float
         _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
         _memory_group.manage(&_mul_cell_to_output_res);
@@ -453,7 +453,7 @@
 
     // Hidden.
     _hidden_tanh.configure(compile_context, cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
-    // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel
+    // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplication
     _memory_group.manage(&_hidden_mul_res);
     const TensorInfo hidden_mul_res(_input_gate.info()->tensor_shape(), 1, DataType::S32);
     _hidden_mul_res.allocator()->init(hidden_mul_res);
@@ -696,8 +696,8 @@
     if(lstm_params.has_peephole_opt())
     {
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
-        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
-                                                                              RoundingPolicy::TO_ZERO));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
+                                                                        RoundingPolicy::TO_ZERO));
         const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
         ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
@@ -766,8 +766,8 @@
 
         if(lstm_params.has_peephole_opt())
         {
-            ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
-                                                                                  RoundingPolicy::TO_ZERO));
+            ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
+                                                                            RoundingPolicy::TO_ZERO));
             const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
             ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
             ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
@@ -784,8 +784,8 @@
         ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_outstage_info, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC, 1.f, 1.f)));
     }
     // Cell.
-    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
-    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&input_gate_info, cell_state_in, &cell_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&input_gate_info, cell_state_in, &cell_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
     ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
     if(quantized_cell_clip > 0)
     {
@@ -809,8 +809,8 @@
         // Here we are not using the output stage because all operations are done in float
         // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
         // ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
-        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_out, lstm_params.cell_to_output_weights(), &output_outstage_info, 1.f, ConvertPolicy::SATURATE,
-                                                                              RoundingPolicy::TO_ZERO));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_out, lstm_params.cell_to_output_weights(), &output_outstage_info, 1.f, ConvertPolicy::SATURATE,
+                                                                        RoundingPolicy::TO_ZERO));
         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
     }
 
@@ -830,7 +830,7 @@
     const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);
 
     ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.hidden_state_scale() == 0);
-    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
     const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
     ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
     gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
@@ -926,7 +926,7 @@
 
     if(_has_peephole)
     {
-        CLScheduler::get().enqueue(_pixelwise_mul_cell_to_forget);
+        _pixelwise_mul_cell_to_forget.run();
         _cell_to_forget_outstage.run();
         _accumulate_cell_forget.run();
     }
@@ -968,7 +968,7 @@
 
         if(_has_peephole)
         {
-            CLScheduler::get().enqueue(_pixelwise_mul_cell_to_input);
+            _pixelwise_mul_cell_to_input.run();
             _cell_to_input_outstage.run();
             _accumulate_cell_input.run();
         }
@@ -982,8 +982,8 @@
     }
 
     // Cell.
-    CLScheduler::get().enqueue(_pixelwise_mul_forget_cell);
-    CLScheduler::get().enqueue(_pixelwise_mul_input_cell);
+    _pixelwise_mul_forget_cell.run();
+    _pixelwise_mul_input_cell.run();
     _add_forget_cell.run();
     if(_has_cell_clipping)
     {
@@ -998,7 +998,7 @@
     _accumulate_input_recurrent_output.run();
     if(_has_peephole)
     {
-        CLScheduler::get().enqueue(_pixelwise_mul_cell_to_output);
+        _pixelwise_mul_cell_to_output.run();
         _cell_to_output_outstage.run();
         _accumulate_cell_to_output.run();
     }
@@ -1012,7 +1012,7 @@
 
     // Hidden.
     _hidden_tanh.run();
-    CLScheduler::get().enqueue(_pixelwise_mul_hidden);
+    _pixelwise_mul_hidden.run();
     _hidden_outstage.run();
 
     // Projection.
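Note: cell_state_in loses its const qualifier here (and in CLLSTMLayer above) because CLPixelWiseMultiplication::configure() takes non-const ICLTensor pointers for its inputs; the tensor is still only read by the peephole and cell multiplications.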
diff --git a/tests/validation/CL/LSTMLayer.cpp b/tests/validation/CL/LSTMLayer.cpp
index 8e8ff8d..a550613 100644
--- a/tests/validation/CL/LSTMLayer.cpp
+++ b/tests/validation/CL/LSTMLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -134,9 +134,10 @@
                input_info, input_weights_info, recurrent_weights_info, cell_bias_info, projection_bias_info, cell_state_info, output_info, scratch_info, info, expected)
 {
     LSTMParams<ITensorInfo> lstm_params_info;
-    lstm_params_info.set_peephole_params(&cell_bias_info, &cell_bias_info)
+    auto cell_bias_clone = cell_bias_info.clone();
+    lstm_params_info.set_peephole_params(cell_bias_clone.get(), cell_bias_clone.get())
                     .set_projection_params(&recurrent_weights_info, &projection_bias_info)
-                    .set_cifg_params(&input_weights_info, &recurrent_weights_info, &cell_bias_info, &cell_bias_info);
+                    .set_cifg_params(&input_weights_info, &recurrent_weights_info, cell_bias_clone.get(), cell_bias_clone.get());
 
     ARM_COMPUTE_EXPECT(bool(CLLSTMLayer::validate(&input_info.clone()->set_is_resizable(false), &input_weights_info.clone()->set_is_resizable(false), &input_weights_info.clone()->set_is_resizable(false),
                                                   &input_weights_info.clone()->set_is_resizable(false), &recurrent_weights_info.clone()->set_is_resizable(false), &recurrent_weights_info.clone()->set_is_resizable(false),
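Note: the validate test now clones cell_bias_info before wiring it into LSTMParams, presumably because the peephole/CIFG setters take mutable ITensorInfo pointers after this change; ITensorInfo::clone() returns an owned std::unique_ptr<ITensorInfo> that can be handed out as a non-const pointer without touching the shared test datum. The same fix is applied to the NEON test below.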
diff --git a/tests/validation/NEON/LSTMLayer.cpp b/tests/validation/NEON/LSTMLayer.cpp
index a14496a..0850dc6 100644
--- a/tests/validation/NEON/LSTMLayer.cpp
+++ b/tests/validation/NEON/LSTMLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -134,9 +134,10 @@
                input_info, input_weights_info, recurrent_weights_info, cell_bias_info, projection_bias_info, cell_state_info, output_info, scratch_info, info, expected)
 {
     LSTMParams<ITensorInfo> lstm_params_info;
-    lstm_params_info.set_peephole_params(&cell_bias_info, &cell_bias_info)
+    auto cell_bias_clone = cell_bias_info.clone();
+    lstm_params_info.set_peephole_params(cell_bias_clone.get(), cell_bias_clone.get())
                     .set_projection_params(&recurrent_weights_info, &projection_bias_info)
-                    .set_cifg_params(&input_weights_info, &recurrent_weights_info, &cell_bias_info, &cell_bias_info);
+                    .set_cifg_params(&input_weights_info, &recurrent_weights_info, cell_bias_clone.get(), cell_bias_clone.get());
 
     ARM_COMPUTE_EXPECT(bool(NELSTMLayer::validate(&input_info.clone()->set_is_resizable(false), &input_weights_info.clone()->set_is_resizable(false), &input_weights_info.clone()->set_is_resizable(false),
                                                   &input_weights_info.clone()->set_is_resizable(false), &recurrent_weights_info.clone()->set_is_resizable(false), &recurrent_weights_info.clone()->set_is_resizable(false),