COMPMID-817: Tuner: Port kernel LWS hints to the new static tuner design.

Change-Id: Iaabb1153c2abe0400ec79d51a21347debe92d642
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/134062
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/CL/kernels/CLCol2ImKernel.cpp b/src/core/CL/kernels/CLCol2ImKernel.cpp
index e15da72..4e44420 100644
--- a/src/core/CL/kernels/CLCol2ImKernel.cpp
+++ b/src/core/CL/kernels/CLCol2ImKernel.cpp
@@ -110,21 +110,9 @@
 
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("col2im", build_opts.options()));
 
-    // Configure the local work size for Bifrost with a value obtained
-    // via exhaustive autotuning over 30 representative tensor shapes.
-    const GPUTarget gpu_target = get_target();
-    if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
-    {
-        if((_convolved_dims.first == 7) || (_convolved_dims.first == 14))
-        {
-            _lws_hint = cl::NDRange(1, 7, 1);
-        }
-        else
-        {
-            _lws_hint = cl::NDRange(1, 8, 1);
-        }
-    }
-
+    // The local work size hint is not set here: it is applied by the static CL tuners (src/runtime/CL/tuners).
+    // Callers that configure this kernel directly should call CLScheduler::get().tune_kernel_static() afterwards.
+
     // Configure kernel window
     auto win_config = validate_and_configure_window(input->info(), output->info(), _convolved_dims);
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
diff --git a/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp b/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp
index 41ff220..c89b16e 100644
--- a/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp
+++ b/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp
@@ -90,15 +90,9 @@
 
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("depthwise_im2col", build_opts.options()));
 
-    // Configure the local work size for Bifrost with a value obtained
-    // via exhaustive autotuning for the MobileNets tensor shapes.
-    const GPUTarget gpu_target = get_target();
-
-    if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
-    {
-        _lws_hint = cl::NDRange(1, 2, 1);
-    }
-
+    // The local work size hint is not set here: it is applied by the static CL tuners (src/runtime/CL/tuners).
+    // Callers that configure this kernel directly should call CLScheduler::get().tune_kernel_static() afterwards.
+
     // Configure  kernel window
     Window win = calculate_max_window(*output->info(), Steps());
     // CLDepthwiseIm2ColKernel doesn't need padding so update_window_and_padding() can be skipped
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 7a9760b..fc52f4e 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -194,51 +194,12 @@
     _output         = output;
     _slide_matrix_b = _input1->info()->num_dimensions() >= _input0->info()->num_dimensions();
 
-    const DataType data_type = input0->info()->data_type();
-    const int      fp_pos    = input0->info()->fixed_point_position();
-
-    // Get target architecture
-    GPUTarget gpu_target = get_target();
-
-    // Configure LWS hint
-    switch(gpu_target)
-    {
-        case GPUTarget::MIDGARD:
-        case GPUTarget::T600:
-        case GPUTarget::T700:
-        case GPUTarget::T800:
-            if(output->info()->dimension(1) == 196)
-            {
-                _lws_hint = cl::NDRange(1, 7);
-            }
-            else
-            {
-                _lws_hint = cl::NDRange(8, 8);
-            }
-            break;
-        case GPUTarget::G71:
-        case GPUTarget::G72:
-        case GPUTarget::G51:
-        case GPUTarget::G51BIG:
-        case GPUTarget::G51LIT:
-        case GPUTarget::TNOX:
-            if(input1->info()->dimension(1) == 24)
-            {
-                // LWS optimized for the 11x11 AlexNet convolution on Bifrost.
-                _lws_hint = cl::NDRange(2, 2);
-            }
-            else if(output->info()->dimension(1) == 196)
-            {
-                _lws_hint = cl::NDRange(1, 7);
-            }
-            else
-            {
-                _lws_hint = cl::NDRange(8, 8);
-            }
-            break;
-        default:
-            _lws_hint = cl::NullRange;
-    }
+    const DataType  data_type  = input0->info()->data_type();
+    const int       fp_pos     = input0->info()->fixed_point_position();
+    const GPUTarget gpu_target = get_target();
+
+    // The local work size hint is not set here: it is applied by the static CL tuners (src/runtime/CL/tuners).
+    // Callers that configure this kernel directly should call CLScheduler::get().tune_kernel_static() afterwards.
 
     ElementsProcessed num_elements_processed{};
 
diff --git a/src/core/CL/kernels/CLGEMMMatrixVectorMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixVectorMultiplyKernel.cpp
index 1d6f388..d8ecd50 100644
--- a/src/core/CL/kernels/CLGEMMMatrixVectorMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixVectorMultiplyKernel.cpp
@@ -110,14 +110,9 @@
         _kernel.setArg<int>(idx++, -_input1->info()->quantization_info().offset);
     }
 
-    // Configure the local work size for Bifrost with a value obtained
-    // via exhaustive autotuning for the MobileNets tensor shapes.
-    const GPUTarget gpu_target = get_target();
-    if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
-    {
-        _lws_hint = cl::NDRange(1, 1, 1);
-    }
-
+    // The local work size hint is not set here: it is applied by the static CL tuners (src/runtime/CL/tuners).
+    // Callers that configure this kernel directly should call CLScheduler::get().tune_kernel_static() afterwards.
+
     // Configure kernel window
     const unsigned int num_elems_read_per_iteration = 4;
 
diff --git a/src/core/CL/kernels/CLIm2ColKernel.cpp b/src/core/CL/kernels/CLIm2ColKernel.cpp
index 378456c..53a4dca 100644
--- a/src/core/CL/kernels/CLIm2ColKernel.cpp
+++ b/src/core/CL/kernels/CLIm2ColKernel.cpp
@@ -61,7 +61,7 @@
 } // namespace
 
 CLIm2ColKernel::CLIm2ColKernel()
-    : _input(nullptr), _output(nullptr), _convolved_dims(), _num_elems_processed_per_iteration(1), _run_func(nullptr), _kernel_dims()
+    : _input(nullptr), _output(nullptr), _conv_info(), _convolved_dims(), _num_elems_processed_per_iteration(1), _run_func(nullptr), _kernel_dims()
 {
 }
 
@@ -74,6 +74,7 @@
 
     _input       = input;
     _output      = output;
+    _conv_info   = conv_info;
     _kernel_dims = kernel_dims;
 
     const DataType  data_type  = input->info()->data_type();
@@ -190,10 +191,10 @@
                 {
                     vector_size = kernel_dims.width;
                 }
-                // Local work size and vector size optimized for the 11x11 AlexNet convolution on Bifrost.
+                // Vector size optimized for the 11x11 AlexNet convolution on Bifrost. The matching local work size hint
+                // is no longer set here: it is applied by the static CL tuners (src/runtime/CL/tuners).
                 if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) && kernel_dims.width == 11)
                 {
-                    _lws_hint   = cl::NDRange(1, 1, 1);
                     vector_size = 8;
                 }
                 const size_t width_mod_vector_size = kernel_dims.width % vector_size;
diff --git a/src/core/CL/kernels/CLPoolingLayerKernel.cpp b/src/core/CL/kernels/CLPoolingLayerKernel.cpp
index 3091df4..b242c55 100644
--- a/src/core/CL/kernels/CLPoolingLayerKernel.cpp
+++ b/src/core/CL/kernels/CLPoolingLayerKernel.cpp
@@ -208,8 +208,7 @@
     _output    = output;
     _pool_info = pool_info;
 
-    const GPUTarget gpu_target = get_target();
-    const DataType  data_type  = input->info()->data_type();
+    const DataType data_type = input->info()->data_type();
 
     // Set build options
     CLBuildOptions build_opts;
@@ -273,20 +272,13 @@
     ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
     ICLKernel::configure(std::get<1>(win_config));
 
-    // Configure the local work size (hint) from the first two dimensions of the global work size.
-    // On Bifrost, this works for up to 35x35xC filters, for which the pooling_layer_3_optimized
-    // kernel is launched with gws=(9, 33, C). In any case, the hint will be ignored if it is
-    // invalid (e.g. exceeds the maximum workgroup size that the kernel can be launched with).
+    // The NCHW local work size hint for Bifrost is not set here: it is applied by the static CL tuners (src/runtime/CL/tuners).
+    // Callers that configure this kernel directly should call CLScheduler::get().tune_kernel_static() afterwards.
     if(data_layout == DataLayout::NCHW)
     {
         CLPoolingConfig pooling_config     = std::get<2>(win_config);
         _num_elems_processed_per_iteration = pooling_config.first;
         _border_size                       = pooling_config.second;
-        if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
-        {
-            cl::NDRange gws = ICLKernel::gws_from_window(std::get<1>(win_config));
-            _lws_hint       = cl::NDRange(gws[0], gws[1], 1);
-        }
     }
     else
     {
diff --git a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
index 676a121..c2b24e3 100644
--- a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
@@ -134,6 +134,8 @@
     _input_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
     _im2col_kernel.set_target(gpu_target);
     _im2col_kernel.configure(input, &_input_reshaped, Size2D(weights_w, weights_h), conv_info, append_bias, depth_multiplier);
+    // Apply the static local work size tuning (if any) to the im2col kernel
+    CLScheduler::get().tune_kernel_static(_im2col_kernel);
 
     // Weights reshape configuration
     const TensorShape shape_weights_reshape(patch_size, weights_z);
@@ -149,6 +151,8 @@
     _v2mm_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_data_type(v2mm_dt).set_tensor_shape(shape_v2mm_out));
     _v2mm_kernel.set_target(gpu_target);
     _v2mm_kernel.configure(&_input_reshaped, &_weights_reshaped, &_v2mm_output);
+    // Apply the static local work size tuning (if any) to the matrix-vector multiply kernel
+    CLScheduler::get().tune_kernel_static(_v2mm_kernel);
     _output_reshaped.allocator()->init(_v2mm_output.info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape));
     _vector_to_tensor_kernel.configure(&_v2mm_output, (_is_quantized) ? &_output_reshaped : output, conv_w, conv_h);
 
diff --git a/src/runtime/CL/functions/CLFlattenLayer.cpp b/src/runtime/CL/functions/CLFlattenLayer.cpp
index 9f571b2..f5809a2 100644
--- a/src/runtime/CL/functions/CLFlattenLayer.cpp
+++ b/src/runtime/CL/functions/CLFlattenLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -25,6 +25,7 @@
 
 #include "arm_compute/core/CL/kernels/CLIm2ColKernel.h"
 #include "arm_compute/core/Size2D.h"
+#include "arm_compute/runtime/CL/CLScheduler.h"
 #include "support/ToolchainSupport.h"
 
 using namespace arm_compute;
@@ -34,4 +35,6 @@
     auto k = arm_compute::support::cpp14::make_unique<CLIm2ColKernel>();
     k->configure(input, output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
     _kernel = std::move(k);
+    // Apply the static local work size tuning (if any) to the im2col kernel
+    CLScheduler::get().tune_kernel_static(*_kernel);
 }
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index 44bf283..9248bc5 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -117,6 +117,9 @@
     // Configure im2col kernel
     _memory_group.manage(&_im2col_output);
     _im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
+    CLScheduler::get().tune_kernel_static(_im2col_kernel);
 
     // Configure matrix multiply kernel
+    // NOTE(review): CLGEMMMatrixMultiplyKernel::configure() no longer sets its LWS hint (it moved to the static tuners);
+    // confirm configure_mm() tunes its matrix multiply kernel too, as CLGEMM now does.
     configure_mm(&_im2col_output, weights, output);
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 7f37520..a0ec66f 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -143,7 +143,9 @@
         _transpose_kernel.configure(b, &_tmp_b, mult_transpose1xW_width);
     }
 
+    // Configure the matrix multiply kernel, then apply its static local work size tuning
     _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height));
+    CLScheduler::get().tune_kernel_static(_mm_kernel);
 
     if(_is_interleaved_transposed)
     {
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index 4f87043..27bed44 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -230,10 +230,11 @@
     _gemm_output.allocator()->init(info_gemm);
     _memory_group.manage(&_gemm_output);
 
-    // Configure im2col
+    // Configure im2col, then apply its static tuning
     _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation);
+    CLScheduler::get().tune_kernel_static(_im2col_kernel);
 
-    // Configure GEMM
+    // Configure GEMM; its kernels are tuned as part of configure_mm()
     configure_mm(&_im2col_output, weights, &_gemm_output);
 
     _im2col_output.allocator()->allocate();
@@ -250,8 +251,9 @@
         _gemmlowp_output_stage.configure(&_gemm_output, biases, &_tmp_output, output_multiplier, output_shift, output_quant_info.offset);
     }
 
-    // Configure Col2Im
+    // Configure Col2Im, then apply its static tuning
     _col2im_kernel.configure(_is_quantized ? &_tmp_output : &_gemm_output, output, std::make_pair(conv_w, conv_h));
+    CLScheduler::get().tune_kernel_static(_col2im_kernel);
     if(_is_quantized)
     {
         _tmp_output.allocator()->allocate();
diff --git a/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp b/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp
index 986fe00..31d5cd5 100644
--- a/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp
@@ -163,6 +163,9 @@
     _weights_reshaped.allocator()->allocate();
     _input_im2col_reshaped.allocator()->allocate();
     _gemm_output.allocator()->allocate();
+
+    // Tune the im2col kernel
+    CLScheduler::get().tune_kernel_static(_input_im2col_kernel);
 }
 
 void CLLocallyConnectedLayer::run()
diff --git a/src/runtime/CL/functions/CLPoolingLayer.cpp b/src/runtime/CL/functions/CLPoolingLayer.cpp
index 17875a3..cbe1ce3 100644
--- a/src/runtime/CL/functions/CLPoolingLayer.cpp
+++ b/src/runtime/CL/functions/CLPoolingLayer.cpp
@@ -63,6 +63,9 @@
             ARM_COMPUTE_ERROR("Data layout not supported");
     }
     _border_handler.configure(input, _kernel->border_size(), border_mode, pixel_value);
+
+    // Tune the pooling kernel
+    CLScheduler::get().tune_kernel_static(*_kernel);
 }
 
 Status CLPoolingLayer::validate(const ITensorInfo *input, const ITensorInfo *output, const PoolingLayerInfo &pool_info)
diff --git a/src/runtime/CL/tuners/BifrostTuner.cpp b/src/runtime/CL/tuners/BifrostTuner.cpp
index c0ebd24..edd074b 100644
--- a/src/runtime/CL/tuners/BifrostTuner.cpp
+++ b/src/runtime/CL/tuners/BifrostTuner.cpp
@@ -124,15 +124,193 @@
         k.set_lws_hint(lws_hint);
     }
 }
+
+/** Checks whether a GPU target is one of the Bifrost GPUs covered by the static heuristics in this file
+ *
+ * @param[in] gpu_target GPU target to check
+ *
+ * @return True if @p gpu_target is G71, G72, G51, G51BIG, G51LIT or TNOX
+ */
+bool is_tuned_bifrost_target(GPUTarget gpu_target)
+{
+    return gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX);
+}
+
+/** Tunes a @ref CLCol2ImKernel for a bifrost target
+ *
+ * @param[in,out] k Kernel to tune
+ */
+void tune_col2im_kernel(CLCol2ImKernel &k)
+{
+    cl::NDRange     lws_hint   = k.lws_hint();
+    const GPUTarget gpu_target = k.get_target();
+
+    // Configure the local work size for Bifrost with a value obtained
+    // via exhaustive autotuning over 30 representative tensor shapes.
+    if(is_tuned_bifrost_target(gpu_target))
+    {
+        if((k._convolved_dims.first == 7) || (k._convolved_dims.first == 14))
+        {
+            lws_hint = cl::NDRange(1, 7, 1);
+        }
+        else
+        {
+            lws_hint = cl::NDRange(1, 8, 1);
+        }
+    }
+
+    k.set_lws_hint(lws_hint);
+}
+
+/** Tunes a @ref CLIm2ColKernel for a bifrost target
+ *
+ * @note This replaces the hint previously set in CLIm2ColKernel::configure() next to vector_size = 8 for
+ *       11-wide kernels on Bifrost; keep the conditions below in sync with that branch.
+ *
+ * @param[in,out] k Kernel to tune
+ */
+void tune_im2col_kernel(CLIm2ColKernel &k)
+{
+    cl::NDRange     lws_hint   = k.lws_hint();
+    const GPUTarget gpu_target = k.get_target();
+
+    // Local work size optimized for the 11x11 AlexNet convolution on Bifrost.
+    if(is_tuned_bifrost_target(gpu_target) && k._kernel_dims.width == 11)
+    {
+        // NOTE(review): square kernels are excluded, so a square 11x11 kernel never gets this hint here;
+        // confirm this matches the path CLIm2ColKernel::configure() takes for square kernels.
+        const bool is_square_kernel = (k._kernel_dims.width == k._kernel_dims.height);
+        if(!is_square_kernel && k._kernel_dims.width > 1 && !k._conv_info.has_padding())
+        {
+            lws_hint = cl::NDRange(1, 1, 1);
+        }
+    }
+    k.set_lws_hint(lws_hint);
+}
+
+/** Tunes a @ref CLDepthwiseIm2ColKernel for a bifrost target
+ *
+ * @param[in,out] k Kernel to tune
+ */
+void tune_depthwise_im2col_kernel(CLDepthwiseIm2ColKernel &k)
+{
+    cl::NDRange     lws_hint   = k.lws_hint();
+    const GPUTarget gpu_target = k.get_target();
+
+    // Configure the local work size for Bifrost with a value obtained
+    // via exhaustive autotuning for the MobileNets tensor shapes.
+    if(is_tuned_bifrost_target(gpu_target))
+    {
+        lws_hint = cl::NDRange(1, 2, 1);
+    }
+
+    k.set_lws_hint(lws_hint);
+}
+
+/** Tunes a @ref CLGEMMMatrixVectorMultiplyKernel for a bifrost target
+ *
+ * @param[in,out] k Kernel to tune
+ */
+void tune_gemv_kernel(CLGEMMMatrixVectorMultiplyKernel &k)
+{
+    cl::NDRange     lws_hint   = k.lws_hint();
+    const GPUTarget gpu_target = k.get_target();
+
+    // Configure the local work size for Bifrost with a value obtained
+    // via exhaustive autotuning for the MobileNets tensor shapes.
+    if(is_tuned_bifrost_target(gpu_target))
+    {
+        lws_hint = cl::NDRange(1, 1, 1);
+    }
+
+    k.set_lws_hint(lws_hint);
+}
+
+/** Tunes a @ref CLGEMMMatrixMultiplyKernel for a bifrost target
+ *
+ * @note On targets rejected by is_tuned_bifrost_target() the hint is reset to cl::NullRange.
+ *
+ * @param[in,out] k Kernel to tune
+ */
+void tune_gemm_kernel(CLGEMMMatrixMultiplyKernel &k)
+{
+    const GPUTarget gpu_target = k.get_target();
+    cl::NDRange     lws_hint   = cl::NullRange;
+
+    if(is_tuned_bifrost_target(gpu_target))
+    {
+        if(k._input1->info()->dimension(1) == 24)
+        {
+            // LWS optimized for the 11x11 AlexNet convolution on Bifrost.
+            lws_hint = cl::NDRange(2, 2);
+        }
+        else if(k._output->info()->dimension(1) == 196)
+        {
+            lws_hint = cl::NDRange(1, 7);
+        }
+        else
+        {
+            lws_hint = cl::NDRange(8, 8);
+        }
+    }
+
+    k.set_lws_hint(lws_hint);
+}
+
+/** Tunes a @ref CLPoolingLayerKernel for a bifrost target
+ *
+ * @param[in,out] k Kernel to tune
+ */
+void tune_pooling_kernel(CLPoolingLayerKernel &k)
+{
+    cl::NDRange     lws_hint   = k.lws_hint();
+    const GPUTarget gpu_target = k.get_target();
+
+    // Configure the local work size (hint) from the first two dimensions of the global work size.
+    // On Bifrost, this works for up to 35x35xC filters, for which the pooling_layer_3_optimized
+    // kernel is launched with gws=(9, 33, C). In any case, the hint will be ignored if it is
+    // invalid (e.g. exceeds the maximum workgroup size that the kernel can be launched with).
+    if(k._input->info()->data_layout() == DataLayout::NCHW && is_tuned_bifrost_target(gpu_target))
+    {
+        cl::NDRange gws = ICLKernel::gws_from_window(k.window());
+        lws_hint        = cl::NDRange(gws[0], gws[1], 1);
+    }
+
+    k.set_lws_hint(lws_hint);
+}
 } // namespace
 
 void BifrostTuner::tune_kernel_static(ICLKernel &kernel)
 {
-    // Continue on tuning if dynamic tuning
+    // Dispatch to the kernel-specific static heuristic; kernels without one are left untouched
     if(dynamic_cast<CLDirectConvolutionLayerKernel *>(&kernel) != nullptr)
     {
         tune_direct_convolution_kernel(*utils::cast::polymorphic_downcast<CLDirectConvolutionLayerKernel *>(&kernel));
     }
+    else if(dynamic_cast<CLCol2ImKernel *>(&kernel) != nullptr)
+    {
+        tune_col2im_kernel(*utils::cast::polymorphic_downcast<CLCol2ImKernel *>(&kernel));
+    }
+    else if(dynamic_cast<CLIm2ColKernel *>(&kernel) != nullptr)
+    {
+        tune_im2col_kernel(*utils::cast::polymorphic_downcast<CLIm2ColKernel *>(&kernel));
+    }
+    else if(dynamic_cast<CLDepthwiseIm2ColKernel *>(&kernel) != nullptr)
+    {
+        tune_depthwise_im2col_kernel(*utils::cast::polymorphic_downcast<CLDepthwiseIm2ColKernel *>(&kernel));
+    }
+    else if(dynamic_cast<CLGEMMMatrixVectorMultiplyKernel *>(&kernel) != nullptr)
+    {
+        tune_gemv_kernel(*utils::cast::polymorphic_downcast<CLGEMMMatrixVectorMultiplyKernel *>(&kernel));
+    }
+    else if(dynamic_cast<CLGEMMMatrixMultiplyKernel *>(&kernel) != nullptr)
+    {
+        tune_gemm_kernel(*utils::cast::polymorphic_downcast<CLGEMMMatrixMultiplyKernel *>(&kernel));
+    }
+    else if(dynamic_cast<CLPoolingLayerKernel *>(&kernel) != nullptr)
+    {
+        tune_pooling_kernel(*utils::cast::polymorphic_downcast<CLPoolingLayerKernel *>(&kernel));
+    }
 }
 
 void BifrostTuner::tune_kernel_dynamic(ICLKernel &kernel)
diff --git a/src/runtime/CL/tuners/MidgardTuner.cpp b/src/runtime/CL/tuners/MidgardTuner.cpp
new file mode 100644
index 0000000..2c4b1ac
--- /dev/null
+++ b/src/runtime/CL/tuners/MidgardTuner.cpp
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+// NOTE(review): "MIdgardTuner.h" differs in case from the MidgardTuner class/file name; confirm the header's on-disk name.
+#include "arm_compute/runtime/CL/tuners/MIdgardTuner.h"
+
+#include "arm_compute/core/CL/CLHelpers.h"
+#include "arm_compute/core/CL/CLKernels.h"
+#include "arm_compute/core/utils/misc/Cast.h"
+
+namespace arm_compute
+{
+namespace tuners
+{
+namespace
+{
+/** Tunes a @ref CLGEMMMatrixMultiplyKernel for a midgard target
+ *
+ * @note On any target other than MIDGARD, T600, T700 or T800 the hint is reset to cl::NullRange.
+ *
+ * @param[in,out] k Kernel to tune
+ */
+void tune_gemm_kernel(CLGEMMMatrixMultiplyKernel &k)
+{
+    cl::NDRange     lws_hint   = k.lws_hint();
+    const GPUTarget gpu_target = k.get_target();
+
+    switch(gpu_target)
+    {
+        case GPUTarget::MIDGARD:
+        case GPUTarget::T600:
+        case GPUTarget::T700:
+        case GPUTarget::T800:
+            if(k._output->info()->dimension(1) == 196)
+            {
+                lws_hint = cl::NDRange(1, 7);
+            }
+            else
+            {
+                lws_hint = cl::NDRange(8, 8);
+            }
+            break;
+        default:
+            lws_hint = cl::NullRange;
+    }
+
+    k.set_lws_hint(lws_hint);
+}
+} // namespace
+
+void MidgardTuner::tune_kernel_static(ICLKernel &kernel)
+{
+    // Only the GEMM matrix multiply kernel has a Midgard-specific static heuristic; other kernels are left untouched
+    if(dynamic_cast<CLGEMMMatrixMultiplyKernel *>(&kernel) != nullptr)
+    {
+        tune_gemm_kernel(*utils::cast::polymorphic_downcast<CLGEMMMatrixMultiplyKernel *>(&kernel));
+    }
+}
+
+void MidgardTuner::tune_kernel_dynamic(ICLKernel &kernel)
+{
+    // No dynamic tuning is performed for Midgard targets
+    ARM_COMPUTE_UNUSED(kernel);
+}
+} // namespace tuners
+} // namespace arm_compute