COMPMID-3385: Async support to CLArithmetic* kernels/functions Pt.1

Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Change-Id: I94007565e688f8a0aead4f14c9fc30bfd9f9f7eb
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3613
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/src/runtime/CL/functions/CLQLSTMLayer.cpp b/src/runtime/CL/functions/CLQLSTMLayer.cpp
index 8c45a98..c5c4aa3 100644
--- a/src/runtime/CL/functions/CLQLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLQLSTMLayer.cpp
@@ -213,7 +213,7 @@
         _projection_reduction.configure(compile_context, _projection_weights, &_projection_eff_bias, GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
         if(_projection_bias != nullptr)
         {
-            _projection_bias_add.configure(compile_context, ArithmeticOperation::ADD, _projection_bias, &_projection_eff_bias, &_projection_eff_bias, ConvertPolicy::SATURATE);
+            _projection_bias_add.configure(compile_context, _projection_bias, &_projection_eff_bias, &_projection_eff_bias, ConvertPolicy::SATURATE);
         }
     }
 
@@ -255,7 +255,7 @@
                  &_mm_recurrent_to_forget_res, &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale,
                  mm_out_info, forget_gate_outstage_info);
 
-    _accumulate_input_recurrent_forget.configure(compile_context, ArithmeticOperation::ADD, &_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
+    _accumulate_input_recurrent_forget.configure(compile_context, &_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
                                                  ConvertPolicy::SATURATE);
     _input_to_forget_outstage_res.allocator()->allocate();
 
@@ -270,7 +270,7 @@
         quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
         _cell_to_forget_outstage.configure(compile_context, &_mul_cell_to_forget_res, nullptr, &_cell_to_forget_outstage_res, gemmlowp_info);
         _mul_cell_to_forget_res.allocator()->allocate();
-        _accumulate_cell_forget.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
+        _accumulate_cell_forget.configure(compile_context, &_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
                                           ConvertPolicy::SATURATE);
         _cell_to_forget_outstage_res.allocator()->allocate();
     }
@@ -307,7 +307,7 @@
                  &_mm_recurrent_to_cell_res, &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale,
                  mm_out_info, cell_outstage_info);
 
-    _accumulate_input_recurrent_modulation.configure(compile_context, ArithmeticOperation::ADD, &_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res,
+    _accumulate_input_recurrent_modulation.configure(compile_context, &_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res,
                                                      ConvertPolicy::SATURATE);
     _input_to_cell_outstage_res.allocator()->allocate();
 
@@ -333,7 +333,7 @@
     if(_has_cifg)
     {
         _ones.allocator()->init(*_forget_gate.info());
-        _input_gate_sub.configure(compile_context, ArithmeticOperation::SUB, &_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
+        _input_gate_sub.configure(compile_context, &_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
         _ones.allocator()->allocate();
     }
     else
@@ -350,7 +350,7 @@
                      output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
                      &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
                      mm_out_info, input_outstage_info);
-        _accumulate_input_recurrent_input.configure(compile_context, ArithmeticOperation::ADD, &_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res,
+        _accumulate_input_recurrent_input.configure(compile_context, &_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res,
                                                     ConvertPolicy::SATURATE);
         _input_to_input_outstage_res.allocator()->allocate();
 
@@ -365,7 +365,7 @@
             _memory_group.manage(&_cell_to_input_outstage_res);
             _cell_to_input_outstage.configure(compile_context, &_mul_cell_to_input_res, nullptr, &_cell_to_input_outstage_res, gemmlowp_info);
             _mul_cell_to_input_res.allocator()->allocate();
-            _accumulate_cell_input.configure(ArithmeticOperation::ADD, &_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
+            _accumulate_cell_input.configure(compile_context, &_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE); // honour the caller-supplied compile context
             _cell_to_input_outstage_res.allocator()->allocate();
         }
 
@@ -391,7 +391,7 @@
     _mul_input_cell_res.allocator()->init(mul_input_cell_info);
     _pixelwise_mul_input_cell.configure(compile_context, &_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
     _cell_gate.allocator()->allocate();
-    _add_forget_cell.configure(compile_context, ArithmeticOperation::ADD, &_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
+    _add_forget_cell.configure(compile_context, &_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
     _mul_input_cell_res.allocator()->allocate();
     _forget_gate.allocator()->allocate();
     if(_has_cell_clipping)
@@ -412,7 +412,7 @@
                  &_mm_recurrent_to_output_res, &_recurrent_to_output_outstage_res, recurrent_to_output_scale,
                  mm_out_info, output_outstage_info);
 
-    _accumulate_input_recurrent_output.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_output_outstage_res, &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res,
+    _accumulate_input_recurrent_output.configure(compile_context, &_recurrent_to_output_outstage_res, &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res,
                                                  ConvertPolicy::SATURATE);
     _input_to_output_outstage_res.allocator()->allocate();
 
@@ -431,7 +431,7 @@
         _cell_to_output_outstage.configure(compile_context, &_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info);
         _mul_cell_to_output_res.allocator()->allocate();
 
-        _accumulate_cell_to_output.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res,
+        _accumulate_cell_to_output.configure(compile_context, &_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res,
                                              ConvertPolicy::SATURATE);
         _cell_to_output_outstage_res.allocator()->allocate();
     }
@@ -510,7 +510,7 @@
             accumulate_destination = &_projection_accumulate_res;
         }
 
-        _accumulate_projection.configure(compile_context, ArithmeticOperation::ADD, &_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
+        _accumulate_projection.configure(compile_context, &_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
         _projection_outstage_res.allocator()->allocate();
 
         if(_projection_tensor_copy_required)
@@ -647,8 +647,8 @@
         if(lstm_params.projection_bias() != nullptr)
         {
             ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.projection_bias(), 1, DataType::S32);
-            ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, lstm_params.projection_bias(), &projection_eff_bias_info,
-                                                                                       &projection_eff_bias_info, ConvertPolicy::SATURATE));
+            ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(lstm_params.projection_bias(), &projection_eff_bias_info,
+                                                                       &projection_eff_bias_info, ConvertPolicy::SATURATE));
         }
     }
 
@@ -691,7 +691,7 @@
     const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));
 
-    ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
 
     if(lstm_params.has_peephole_opt())
     {
@@ -701,7 +701,7 @@
         const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
         ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
-        ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
     }
 
     if(has_layer_norm)
@@ -726,7 +726,7 @@
     const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &input_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
 
-    ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
 
     if(has_layer_norm)
     {
@@ -743,7 +743,7 @@
     if(lstm_params.has_cifg_opt())
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MSG(lstm_params.input_gate_bias() != nullptr, "Input gate bias must not be present when CIFG is used");
-        ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::SUB, &input_gate_info, &forget_gate_info, &forget_gate_info, ConvertPolicy::SATURATE));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticSubtraction::validate(&input_gate_info, &forget_gate_info, &forget_gate_info, ConvertPolicy::SATURATE));
     }
     else
     {
@@ -762,7 +762,7 @@
         const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
         ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));
 
-        ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
 
         if(lstm_params.has_peephole_opt())
         {
@@ -771,7 +771,7 @@
             const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
             ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
             ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
-            ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
+            ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
         }
 
         if(has_layer_norm)
@@ -786,7 +786,7 @@
     // Cell.
     ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
     ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&input_gate_info, cell_state_in, &cell_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
-    ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
     if(quantized_cell_clip > 0)
     {
         ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip,
@@ -801,7 +801,7 @@
     const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));
 
-    ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
     if(lstm_params.has_peephole_opt())
     {
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_output_weights(), 1, DataType::QSYMM16);
@@ -811,7 +811,7 @@
         // ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
         ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_out, lstm_params.cell_to_output_weights(), &output_outstage_info, 1.f, ConvertPolicy::SATURATE,
                                                                               RoundingPolicy::TO_ZERO));
-        ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
     }
 
     if(has_layer_norm)
@@ -866,7 +866,7 @@
             ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(*output_state_out, projection_outstage_info));
         }
 
-        ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
 
         if(projection_tensor_copy_required)
         {
@@ -922,13 +922,13 @@
 
     _mm_recurrent_to_forget.run();
     _recurrent_to_forget_outstage.run();
-    CLScheduler::get().enqueue(_accumulate_input_recurrent_forget);
+    _accumulate_input_recurrent_forget.run();
 
     if(_has_peephole)
     {
         CLScheduler::get().enqueue(_pixelwise_mul_cell_to_forget);
         _cell_to_forget_outstage.run();
-        CLScheduler::get().enqueue(_accumulate_cell_forget);
+        _accumulate_cell_forget.run();
     }
 
     if(_has_layer_norm)
@@ -944,7 +944,7 @@
 
     _mm_recurrent_to_cell.run();
     _recurrent_to_cell_outstage.run();
-    CLScheduler::get().enqueue(_accumulate_input_recurrent_modulation);
+    _accumulate_input_recurrent_modulation.run();
 
     if(_has_layer_norm)
     {
@@ -956,7 +956,7 @@
     // Input gate
     if(_has_cifg)
     {
-        CLScheduler::get().enqueue(_input_gate_sub);
+        _input_gate_sub.run();
     }
     else
     {
@@ -964,13 +964,13 @@
         _input_to_input_outstage.run();
         _mm_recurrent_to_input.run();
         _recurrent_to_input_outstage.run();
-        CLScheduler::get().enqueue(_accumulate_input_recurrent_input);
+        _accumulate_input_recurrent_input.run();
 
         if(_has_peephole)
         {
             CLScheduler::get().enqueue(_pixelwise_mul_cell_to_input);
             _cell_to_input_outstage.run();
-            CLScheduler::get().enqueue(_accumulate_cell_input);
+            _accumulate_cell_input.run();
         }
 
         if(_has_layer_norm)
@@ -984,7 +984,7 @@
     // Cell.
     CLScheduler::get().enqueue(_pixelwise_mul_forget_cell);
     CLScheduler::get().enqueue(_pixelwise_mul_input_cell);
-    CLScheduler::get().enqueue(_add_forget_cell);
+    _add_forget_cell.run();
     if(_has_cell_clipping)
     {
         _cell_clip.run();
@@ -995,12 +995,12 @@
     _input_to_output_outstage.run();
     _mm_recurrent_to_output.run();
     _recurrent_to_output_outstage.run();
-    CLScheduler::get().enqueue(_accumulate_input_recurrent_output);
+    _accumulate_input_recurrent_output.run();
     if(_has_peephole)
     {
         CLScheduler::get().enqueue(_pixelwise_mul_cell_to_output);
         _cell_to_output_outstage.run();
-        CLScheduler::get().enqueue(_accumulate_cell_to_output);
+        _accumulate_cell_to_output.run();
     }
 
     if(_has_layer_norm)
@@ -1026,7 +1026,7 @@
             _projection_output_to_accumulate_copy.run();
         }
 
-        CLScheduler::get().enqueue(_accumulate_projection);
+        _accumulate_projection.run();
 
         if(_projection_tensor_copy_required)
         {
@@ -1108,7 +1108,7 @@
             CLScheduler::get().enqueue(_projection_reduction);
             if(_projection_bias != nullptr)
             {
-                CLScheduler::get().enqueue(_projection_bias_add);
+                _projection_bias_add.run();
                 _projection_bias->mark_as_unused();
             }