COMPMID-477 - Optimized batched case in CLConvolutionLayer

Change-Id: I4ef18f49f1da0cb816aaa0762466b940792c15ed
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/84162
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 39526a2..684e323 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -48,13 +48,13 @@
 {
 }
 
-void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha)
+void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed)
 {
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output);
 
-    if(output->info()->dimension(1) == 1)
+    if(!is_interleaved_transposed)
     {
         ARM_COMPUTE_ERROR_ON(input0->info()->dimension(0) != input1->info()->dimension(1));
     }
@@ -72,64 +72,36 @@
         _lws_hint = cl::NDRange(8, 8);
     }
 
-    std::ostringstream mm_arguments;
-    mm_arguments << "-DWIDTH_MATRIX_B=" << input1->info()->dimension(0) << " ";
+    std::set<std::string> build_opts;
+    build_opts.emplace(("-DCOLS_A=" + support::cpp11::to_string(input0->info()->dimension(0))));
+    build_opts.emplace(("-DCOLS_B=" + support::cpp11::to_string(input1->info()->dimension(0))));
+
     if(is_data_type_fixed_point(input0->info()->data_type()))
     {
-        mm_arguments << "-DALPHA=" << (input0->info()->data_type() == DataType::QS8 ?
-                                       sqcvt_qs8_f32(alpha, input0->info()->fixed_point_position()) :
-                                       sqcvt_qs16_f32(alpha, input0->info()->fixed_point_position()))
-                     << " ";
-        mm_arguments << "-DFIXED_POINT_POSITION=" << input0->info()->fixed_point_position() << " ";
+        build_opts.emplace(("-DALPHA=" + support::cpp11::to_string((input0->info()->data_type() == DataType::QS8 ?
+                                                                    sqcvt_qs8_f32(alpha, input0->info()->fixed_point_position()) :
+                                                                    sqcvt_qs16_f32(alpha, input0->info()->fixed_point_position())))));
+
+        build_opts.emplace(("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input0->info()->fixed_point_position())));
     }
     else
     {
-        mm_arguments << "-DALPHA=" << alpha << " ";
+        build_opts.emplace(("-DALPHA=" + float_to_string_with_full_precision(alpha)));
     }
-    std::set<std::string> build_opts;
 
-    // Check if the output tensor is a vector. If so,the kernel runs the vector-matrix multiplication
-    if(output->info()->dimension(1) == 1)
+    if(is_interleaved_transposed)
     {
-        mm_arguments << "-DWIDTH_VECTOR_A=" << input0->info()->dimension(0) << " ";
-        build_opts.emplace(mm_arguments.str());
-
-        // Create kernel
-        std::string data_type_name = lower_string(string_from_data_type(input0->info()->data_type()));
-        _kernel                    = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(("gemm_vm_" + data_type_name), build_opts));
-
-        // Configure window kernel
-        const unsigned int num_elems_processed_per_iteration_x = max_cl_vector_width / data_size_from_type(input0->info()->data_type());
-
-        Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
-
-        AccessWindowStatic     input0_access(input0->info(), 0, 0, input0->info()->tensor_shape().x(), 1);
-        AccessWindowHorizontal input1_access(input1->info(), 0, num_elems_processed_per_iteration_x);
-        AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration_x);
-
-        update_window_and_padding(win, input0_access, input1_access, output_access);
-
-        Coordinates coord;
-        coord.set_num_dimensions(output->info()->num_dimensions());
-        output_access.set_valid_region(win, ValidRegion(coord, output->info()->tensor_shape()));
-
-        ICLKernel::configure(win);
-    }
-    else
-    {
-        build_opts.emplace(mm_arguments.str());
-
         // Create kernel
         std::string data_type_name = lower_string(string_from_data_type(input0->info()->data_type()));
 
         if(data_type_name == "f32")
         {
             GPUTarget arch_target = get_arch_from_target(get_target());
-            _kernel               = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemm_mm_f32_" + string_from_target(arch_target), build_opts));
+            _kernel               = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemm_mm_interleaved_transposed_f32_" + string_from_target(arch_target), build_opts));
         }
         else
         {
-            _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemm_mm_" + data_type_name, build_opts));
+            _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemm_mm_interleaved_transposed_" + data_type_name, build_opts));
         }
 
         // Configure window kernel
@@ -148,6 +120,44 @@
 
         ICLKernel::configure(win);
     }
+    else // The input tensors have not been reshaped
+    {
+        ARM_COMPUTE_ERROR_ON(input0->info()->dimension(0) != input1->info()->dimension(1));
+
+        // Special case for short outputs: rows processed per thread are capped at 4, and outputs with only 1, 2 or 3 rows use their own row count
+        const unsigned int num_elems_processed_per_iteration_x = max_cl_vector_width / data_size_from_type(input0->info()->data_type());
+        const unsigned int num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->info()->dimension(1)), 4);
+
+        build_opts.emplace(("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type())));
+        build_opts.emplace(("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elems_processed_per_iteration_x)));
+        build_opts.emplace(("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=" + support::cpp11::to_string(num_elems_processed_per_iteration_y)));
+
+        // Create kernel
+        if(is_data_type_fixed_point(input0->info()->data_type()))
+        {
+            std::string kernel_name = "gemm_mm_" + lower_string(string_from_data_type(input0->info()->data_type()));
+            _kernel                 = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel((kernel_name), build_opts));
+        }
+        else
+        {
+            std::string kernel_name = "gemm_mm_floating_point";
+            _kernel                 = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel((kernel_name), build_opts));
+        }
+
+        Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+
+        AccessWindowStatic    input0_access(input0->info(), 0, 0, input0->info()->dimension(0), ceil_to_multiple(input0->info()->dimension(1), num_elems_processed_per_iteration_y));
+        AccessWindowStatic    input1_access(input1->info(), 0, 0, ceil_to_multiple(input1->info()->dimension(0), num_elems_processed_per_iteration_x), input1->info()->dimension(1));
+        AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
+
+        update_window_and_padding(win, input0_access, input1_access, output_access);
+
+        Coordinates coord;
+        coord.set_num_dimensions(output->info()->num_dimensions());
+        output_access.set_valid_region(win, ValidRegion(coord, output->info()->tensor_shape()));
+
+        ICLKernel::configure(win);
+    }
 }
 
 void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -157,9 +167,9 @@
 
     Window slice          = window.first_slice_window_2D();
     Window slice_matrix_b = slice;
-    slice_matrix_b.set(Window::DimX, Window::Dimension(0, _input1->info()->dimension(0), 1));
-    slice_matrix_b.set(Window::DimY, Window::Dimension(0, _input1->info()->dimension(1), 1));
-    slice_matrix_b.set(Window::DimZ, Window::Dimension(0, 1, 1));
+
+    slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1));
+    slice_matrix_b.set(Window::DimY, Window::Dimension(0, 1, 1));
 
     do
     {