COMPMID-661 Add optimal lws heuristics for the Bifrost direct_convolution kernels #45

Change-Id: I9e7ec5ed937fb4e8cab44a11c49a93f3aa01bedb
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/110877
Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
index e1901af..3c5799f 100644
--- a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
+++ b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
@@ -131,6 +131,23 @@
         unsigned int num_elems_written_per_iteration_x = 0;
         unsigned int num_elems_written_per_iteration_y = 0;
 
+        // Through extensive experimentation with over 30 representative tensor
+        // shapes, we found a small number of local work size configurations
+        // that result in nearly optimal execution times. Selecting the right
+        // lws for a given shape, however, required a complex decision tree,
+        // until we constructed a simple feature as described below.
+        //
+        // We started from the number of multiply-accumulate operations for a
+        // convolution layer, which is equal to the product of the input
+        // dimensions 0..2 and the weights dimensions 0..2.  Unfortunately,
+        // this resulted in ties between distinct shapes that required distinct
+        // lws configurations. Replacing the height (dimension 1) of the input with the kernel
+        // size, however, resulted in nearly optimal predictions. We use underscores
+        // in variable names to indicate when they are intentionally misleading.
+        const size_t product_of_weights_dimensions = weights->info()->dimension(0) * weights->info()->dimension(1) * weights->info()->dimension(2);
+        const size_t product_of_input_dimensions_  = input->info()->dimension(0) * weights->info()->dimension(1) * input->info()->dimension(2);
+        const float  mega_ops_                     = 1e-6 * product_of_weights_dimensions * product_of_input_dimensions_;
+
         switch(kernel_size)
         {
             case 1:
@@ -139,6 +156,18 @@
                 num_elems_read_per_iteration_y    = 4;
                 num_elems_written_per_iteration_x = 4;
                 num_elems_written_per_iteration_y = 4;
+                if(mega_ops_ < 1.f)
+                {
+                    _lws_hint = cl::NDRange(1, 1, 8);
+                }
+                else if(mega_ops_ < 7.f)
+                {
+                    _lws_hint = cl::NDRange(1, 1, 4);
+                }
+                else
+                {
+                    _lws_hint = cl::NDRange(1, 1, 2);
+                }
                 break;
             }
             case 3:
@@ -147,6 +176,22 @@
                 num_elems_read_per_iteration_y    = 5;
                 num_elems_written_per_iteration_x = 4;
                 num_elems_written_per_iteration_y = 3;
+                if(mega_ops_ < 1.f)
+                {
+                    _lws_hint = cl::NDRange(1, 1, 8);
+                }
+                else if(mega_ops_ < 13.f)
+                {
+                    _lws_hint = cl::NDRange(2, 1, 4);
+                }
+                else if(mega_ops_ < 50.f)
+                {
+                    _lws_hint = cl::NDRange(3, 1, 4);
+                }
+                else
+                {
+                    _lws_hint = cl::NDRange(2, 1, 6);
+                }
                 break;
             }
             case 5:
@@ -155,6 +200,14 @@
                 num_elems_read_per_iteration_y    = 6;
                 num_elems_written_per_iteration_x = 4;
                 num_elems_written_per_iteration_y = 2;
+                if(mega_ops_ < 2.f || mega_ops_ > 80.f)
+                {
+                    _lws_hint = cl::NDRange(2, 1, 4);
+                }
+                else
+                {
+                    _lws_hint = cl::NDRange(2, 1, 8);
+                }
                 break;
             }
             default: