COMPMID-1706: Fuse the bias addition within CLGEMM

Change-Id: I378f2023f4fa010f195f76716ac07aa86279bfae
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/280
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index b667621..2b004c2 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -48,8 +48,8 @@
 {
 using ElementsProcessed = Steps;
 
-inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
-                                 bool fp_mixed_precision)
+inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float beta,
+                                 bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, bool fp_mixed_precision)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
     ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
@@ -61,9 +61,20 @@
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 2 && reshape_info.reinterpret_input_as_3d(), "The input1 tensor cannot have more than 2 dimensions if input0 has to be reinterpreted as 3D");
 
+    const bool is_beta_one = std::abs(1.0f - beta) < 0.00001f;
+    const bool has_vec_c   = input2 != nullptr && beta != 0.f;
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(has_vec_c && !is_beta_one, "Adding input2 is only supported for beta equal to 1");
+
     if(!is_interleaved_transposed)
     {
         ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));
+
+        if(has_vec_c)
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input2);
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->num_dimensions() > 1, "input2 must be a 1D tensor");
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->dimension(0) != input1->dimension(0), "Length of Vector C must match the number of columns of matrix B");
+        }
     }
     else
     {
@@ -101,6 +112,12 @@
 
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1);
+
+        if(has_vec_c)
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input2);
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->num_dimensions() > 1, "input2 must be a 1D tensor");
+        }
     }
 
     if(output->total_size() != 0)
@@ -113,10 +130,11 @@
     return Status{};
 }
 
-inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output,
-                                                               bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target,
+inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output,
+                                                               float beta, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target,
                                                                ElementsProcessed &num_elements_processed)
 {
+    ARM_COMPUTE_UNUSED(beta);
     bool   window_changed = false;
     Window win{};
     Window win_out{};
@@ -126,6 +144,7 @@
     unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1];
     bool           reinterpret_input_as_3d             = reshape_info.reinterpret_input_as_3d();
     bool           reinterpret_output_as_3d            = (reshape_info.depth_output_gemm3d() != 0);
+    const bool     has_vec_c                           = input2 != nullptr && beta != 0.f;
 
     // In case both input and output have to be reinterpreted as 3D tensors,
     // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false.
@@ -176,6 +195,11 @@
 
         window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
                          update_window_and_padding(win_out, output_access);              // window used to update the padding requirements of output tensor
+        if(has_vec_c)
+        {
+            AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_x);
+            window_changed = window_changed || update_window_and_padding(win, input2_access);
+        }
 
         output_access.set_valid_region(win_out, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
     }
@@ -209,6 +233,11 @@
 
         window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
                          update_window_and_padding(win_out, output_access);              // window used to update the padding requirements of output tensor
+        if(has_vec_c)
+        {
+            AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_x);
+            window_changed = window_changed || update_window_and_padding(win, input2_access);
+        }
 
         Coordinates coord;
         coord.set_num_dimensions(output->num_dimensions());
@@ -227,20 +256,22 @@
 } // namespace
 
 CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel()
-    : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false)
+    : _input0(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false), _has_vec_c(false)
 {
 }
 
-void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
-                                           bool fp_mixed_precision)
+void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta,
+                                           bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, bool fp_mixed_precision)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
 
     // Perform validate step
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info, fp_mixed_precision));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), (input2 != nullptr) ? input2->info() : nullptr, output->info(), beta,
+                                                  is_interleaved_transposed, reshape_info, fp_mixed_precision));
 
     _input0                   = input0;
     _input1                   = input1;
+    _input2                   = input2;
     _output                   = output;
     _reinterpret_input_as_3d  = reshape_info.reinterpret_input_as_3d();
     _reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 0);
@@ -266,7 +297,8 @@
     ElementsProcessed num_elements_processed{};
 
     // Configure kernel window
-    auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info, gpu_target, num_elements_processed);
+    auto win_config = validate_and_configure_window(input0->info(), input1->info(), (input2 != nullptr) ? input2->info() : nullptr, output->info(), beta, is_interleaved_transposed, reshape_info,
+                                                    gpu_target, num_elements_processed);
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     ICLKernel::configure_internal(win_config.second);
 
@@ -288,6 +320,8 @@
 
     const bool is_bifrost = get_arch_from_target(gpu_target) == GPUTarget::BIFROST;
 
+    _has_vec_c = input2 != nullptr && beta != 0.f;
+
     std::string kernel_name;
     if(is_interleaved_transposed)
     {
@@ -351,6 +385,9 @@
         build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elements_processed.x()));
     }
 
+    // Configure vector C addition if necessary
+    build_opts.add_option_if(_has_vec_c, "-DADD_VEC_C");
+
     // Create kernel
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
 
@@ -373,16 +410,18 @@
     _config_id += (is_interleaved_transposed ? support::cpp11::to_string(input1->info()->dimension(0)) : support::cpp11::to_string(input1->info()->dimension(1)));
 }
 
-Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved_transposed,
-                                            const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision)
+Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta,
+                                            bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision)
 {
     // Note: num_elements_processed will be set in validate_and_configure_window()
     ElementsProcessed num_elements_processed{};
     ARM_COMPUTE_UNUSED(alpha);
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, is_interleaved_transposed, reshape_info, fp_mixed_precision));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, input2, output, beta, is_interleaved_transposed, reshape_info, fp_mixed_precision));
     ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(),
                                                               input1->clone().get(),
+                                                              (input2 != nullptr) ? input2->clone().get() : nullptr,
                                                               output->clone().get(),
+                                                              beta,
                                                               is_interleaved_transposed,
                                                               reshape_info,
                                                               gpu_target,
@@ -409,10 +448,12 @@
     slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1));
     slice_matrix_b.set(Window::DimY, Window::Dimension(0, 1, 1));
 
+    const unsigned int num_arguments_vec_c = (_has_vec_c) ? num_arguments_per_1D_tensor() : 0;
+
     if(_reinterpret_input_as_3d)
     {
         // Pass bottom paddings to the kernel if the input has to be reinterpreted as 3D tensor
-        const unsigned int idx0                  = 3 * num_arguments_per_2D_tensor() + 3;
+        const unsigned int idx0                  = 3 * num_arguments_per_2D_tensor() + 3 + num_arguments_vec_c;
         const unsigned int total_cross_plane_pad = _input0->info()->padding().top + _input0->info()->padding().bottom;
         _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
     }
@@ -420,7 +461,7 @@
     if(_reinterpret_output_as_3d)
     {
         // Pass bottom paddings to the kernel if the output has to be reinterpreted as 3D tensor
-        const unsigned int idx0                  = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0);
+        const unsigned int idx0                  = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0) + num_arguments_vec_c;
         const unsigned int total_cross_plane_pad = _output->info()->padding().top + _output->info()->padding().bottom;
         _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
     }
@@ -438,6 +479,10 @@
         unsigned int idx = 0;
         add_2D_tensor_argument(idx, _input0, slice);
         add_2D_tensor_argument(idx, _input1, slice_b);
+        if(_has_vec_c)
+        {
+            add_1D_tensor_argument(idx, _input2, slice);
+        }
         add_2D_tensor_argument(idx, _output, slice);
         _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input0->info()->strides_in_bytes()[2]));
         _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input1->info()->strides_in_bytes()[2]));