COMPMID-816 - Enabled CLConvolutionLayer to use the CLGEMM function instead
of the CLGEMMMatrixMultiplyKernel kernel.

Change-Id: If035fa3d1fb3ff4012442bcd908c370d21aa6657
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/115990
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
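
For reference, a minimal usage sketch (not part of this patch) of the default path a caller
now hits: with a default WeightsInfo() the weights are not pre-reshaped, so the new
CLGEMM-based path is selected. It assumes the usual configure(input, weights, biases,
output, conv_info[, weights_info]) signature; tensor shapes and pad/stride values are
illustrative placeholders.

    #include "arm_compute/runtime/CL/CLScheduler.h"
    #include "arm_compute/runtime/CL/CLTensor.h"
    #include "arm_compute/runtime/CL/functions/CLConvolutionLayer.h"

    using namespace arm_compute;

    void convolution_example()
    {
        CLScheduler::get().default_init();

        CLTensor src, weights, biases, dst;
        src.allocator()->init(TensorInfo(TensorShape(32U, 32U, 16U), 1, DataType::F32));
        weights.allocator()->init(TensorInfo(TensorShape(3U, 3U, 16U, 32U), 1, DataType::F32));
        biases.allocator()->init(TensorInfo(TensorShape(32U), 1, DataType::F32));
        dst.allocator()->init(TensorInfo(TensorShape(32U, 32U, 32U), 1, DataType::F32));

        CLConvolutionLayer conv;
        // Default WeightsInfo(): the weights are not already reshaped, so CLGEMM is used internally.
        conv.configure(&src, &weights, &biases, &dst, PadStrideInfo(1, 1, 1, 1));

        src.allocator()->allocate();
        weights.allocator()->allocate();
        biases.allocator()->allocate();
        dst.allocator()->allocate();

        conv.run();
        CLScheduler::get().sync();
    }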
diff --git a/arm_compute/runtime/CL/functions/CLConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLConvolutionLayer.h
index 3fe6604..f6672ce 100644
--- a/arm_compute/runtime/CL/functions/CLConvolutionLayer.h
+++ b/arm_compute/runtime/CL/functions/CLConvolutionLayer.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,6 +36,7 @@
 #include "arm_compute/core/Types.h"
 #include "arm_compute/runtime/CL/CLMemoryGroup.h"
 #include "arm_compute/runtime/CL/CLTensor.h"
+#include "arm_compute/runtime/CL/functions/CLGEMM.h"
 #include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h"
 #include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h"
 #include "arm_compute/runtime/IMemoryManager.h"
@@ -76,15 +77,20 @@
     bool                     _transpose1xW;
 };
 
-/** Basic function to compute the convolution layer. This function calls the following OpenCL kernels:
+/** Basic function to compute the convolution layer. This function calls the following OpenCL kernels/functions:
  *
- * -# @ref CLWeightsReshapeKernel (executed only once for each configuration)
- * -# @ref CLGEMMTranspose1xWKernel (executed only once for each configuration)
+ * Note: weights that are already reshaped are not supported for quantized asymmetric data types
+ *
  * -# @ref CLIm2ColKernel
- * -# @ref CLGEMMInterleave4x4Kernel
- * -# @ref CLGEMMMatrixMultiplyKernel or @ref CLGEMMLowpMatrixMultiplyCore (if quantized asymmetric)
+ * -# @ref CLGEMMLowpMatrixMultiplyCore (if quantized asymmetric)
  * -# @ref CLGEMMLowpQuantizeDownInt32ToUint8Scale (if quantized asymmetric)
  * -# @ref CLCol2ImKernel
+ *
+ * If the weights are already reshaped:
+ * -# @ref CLGEMMInterleave4x4Kernel
+ * -# @ref CLGEMMMatrixMultiplyKernel
+ * otherwise:
+ * -# @ref CLGEMM
  */
 class CLConvolutionLayer : public IFunction
 {
@@ -119,20 +125,22 @@
      *                                                 except for input of QASYMM8 type where output should be of S32 type.
      * @param is_interleaved_transposed Flag that signals if matrix is interleaved transposed
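+     * @param are_weights_reshaped      Flag that signals if the weights have already been reshaped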
      */
-    void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool is_interleaved_transposed = true);
+    void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool is_interleaved_transposed, bool are_weights_reshaped);
 
 private:
     CLMemoryGroup                                       _memory_group;
     CLConvolutionLayerReshapeWeights                    _reshape_weights;
-    CLIm2ColKernel                                      _input_im2col_kernel;
-    CLGEMMInterleave4x4Kernel                           _input_interleave_kernel;
+    CLIm2ColKernel                                      _im2col_kernel;
+    CLGEMMInterleave4x4Kernel                           _interleave_kernel;
     CLGEMMMatrixMultiplyKernel                          _mm_kernel;
+    CLGEMM                                              _mm_gemm;
     CLGEMMLowpMatrixMultiplyCore                        _mm_gemmlowp;
     CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage;
-    CLCol2ImKernel                                      _output_col2im_kernel;
+    CLCol2ImKernel                                      _col2im_kernel;
 
-    CLTensor _input_im2col_reshaped;
-    CLTensor _input_interleaved_reshaped;
+    CLTensor _im2col_output;
+    CLTensor _interleave_output;
     CLTensor _weights_reshaped;
     CLTensor _weights_transposed;
     CLTensor _gemm_output;
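
The selection logic documented above can be summarised in a small helper; this is an
illustrative sketch rather than code from this patch, and the enum and function names
are hypothetical:

    // Which matrix-multiply backend CLConvolutionLayer ends up using, given the two
    // flags that configure_mm() receives.
    enum class MMBackend
    {
        GEMMLOWP_FUNCTION, // CLGEMMLowpMatrixMultiplyCore + quantized output stage
        GEMM_KERNEL,       // CLGEMMInterleave4x4Kernel + CLGEMMMatrixMultiplyKernel
        GEMM_FUNCTION      // CLGEMM (reshapes its operands internally)
    };

    inline MMBackend select_mm_backend(bool is_quantized_asymmetric, bool are_weights_reshaped)
    {
        if(is_quantized_asymmetric)
        {
            // Pre-reshaped weights are rejected on this path (see configure_mm() below).
            return MMBackend::GEMMLOWP_FUNCTION;
        }
        return are_weights_reshaped ? MMBackend::GEMM_KERNEL : MMBackend::GEMM_FUNCTION;
    }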
diff --git a/src/runtime/CL/functions/CLConvolutionLayer.cpp b/src/runtime/CL/functions/CLConvolutionLayer.cpp
index 2c1ddc3..d115397 100644
--- a/src/runtime/CL/functions/CLConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLConvolutionLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -87,7 +87,6 @@
 {
     _memory_group.acquire();
 
-    cl::CommandQueue q = CLScheduler::get().queue();
     CLScheduler::get().enqueue(_weights_reshape_kernel);
     if(_transpose1xW)
     {
@@ -98,33 +97,49 @@
 }
 
 CLConvolutionLayer::CLConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(memory_manager), _reshape_weights(), _input_im2col_kernel(), _input_interleave_kernel(), _mm_kernel(), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _output_col2im_kernel(),
-      _input_im2col_reshaped(), _input_interleaved_reshaped(), _weights_reshaped(), _weights_transposed(), _gemm_output(), _tmp_output(), _are_weights_reshaped(false), _is_quantized(false),
+    : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _interleave_kernel(), _mm_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(),
+      _col2im_kernel(), _im2col_output(), _interleave_output(), _weights_reshaped(), _weights_transposed(), _gemm_output(), _tmp_output(), _are_weights_reshaped(false), _is_quantized(false),
       _is_interleaved_transposed(false)
 {
 }
 
-void CLConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool is_interleaved_transposed)
+void CLConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool is_interleaved_transposed, bool are_weights_reshaped)
 {
     if(_is_quantized)
     {
-        // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
-        // Extract and negate input and weights offset
-        const QuantizationInfo input_quantization_info   = input->info()->quantization_info();
-        const QuantizationInfo weights_quantization_info = weights->info()->quantization_info();
+        if(are_weights_reshaped)
+        {
+            ARM_COMPUTE_ERROR("Weights already reshaped are not supported with gemmlowp");
+        }
+        else
+        {
+            // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
+            // Extract and negate input and weights offset
+            const QuantizationInfo input_quantization_info   = input->info()->quantization_info();
+            const QuantizationInfo weights_quantization_info = weights->info()->quantization_info();
 
-        input->info()->set_quantization_info(QuantizationInfo(input_quantization_info.scale, -input_quantization_info.offset));
-        weights->info()->set_quantization_info(QuantizationInfo(weights_quantization_info.scale, -weights_quantization_info.offset));
+            input->info()->set_quantization_info(QuantizationInfo(input_quantization_info.scale, -input_quantization_info.offset));
+            weights->info()->set_quantization_info(QuantizationInfo(weights_quantization_info.scale, -weights_quantization_info.offset));
 
-        _mm_gemmlowp.configure(input, weights, output, GEMMInfo(false, false, true /* Reshape weights only for the first run*/));
+            _mm_gemmlowp.configure(input, weights, output, GEMMInfo(false, false, true /* Reshape weights only for the first run*/));
 
-        // Revert back QuantizatioInfo as input and weights could be used in other convolution layers
-        input->info()->set_quantization_info(input_quantization_info);
-        weights->info()->set_quantization_info(weights_quantization_info);
+            // Restore QuantizationInfo as input and weights could be used in other convolution layers
+            input->info()->set_quantization_info(input_quantization_info);
+            weights->info()->set_quantization_info(weights_quantization_info);
+        }
     }
     else
     {
-        _mm_kernel.configure(input, weights, output, 1.f, is_interleaved_transposed);
+        if(are_weights_reshaped)
+        {
+            // Configure matrix multiply kernel
+            _mm_kernel.configure(input, weights, output, 1.f, is_interleaved_transposed);
+        }
+        else
+        {
+            // Configure matrix multiply function
+            _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/));
+        }
     }
 }
 
@@ -133,6 +148,7 @@
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, weights);
+    ARM_COMPUTE_ERROR_ON(weights_info.are_reshaped() && CLScheduler::get().target() == GPUTarget::BIFROST);
     ARM_COMPUTE_ERROR_ON(!weights_info.are_reshaped() && weights->info()->dimension(2) != input->info()->dimension(2));
     ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4);
     ARM_COMPUTE_ERROR_ON(weights_info.are_reshaped() && is_data_type_quantized_asymmetric(input->info()->data_type()));
@@ -158,8 +174,8 @@
 
     // Set the GPU target for matrix multiply and im2col and col2im
     _mm_kernel.set_target(CLScheduler::get().target());
-    _input_im2col_kernel.set_target(CLScheduler::get().target());
-    _output_col2im_kernel.set_target(CLScheduler::get().target());
+    _im2col_kernel.set_target(CLScheduler::get().target());
+    _col2im_kernel.set_target(CLScheduler::get().target());
 
     const bool append_bias = (biases != nullptr) && (!_is_quantized);
     _are_weights_reshaped  = weights_info.are_reshaped();
@@ -183,7 +199,7 @@
 
     // Check if its a "fully connected" convolution
     const bool is_fully_connected_convolution = ((conv_w == 1) && (conv_h == 1));
-    _is_interleaved_transposed                = (!is_fully_connected_convolution && !_is_quantized);
+    _is_interleaved_transposed                = (!is_fully_connected_convolution) && (!_is_quantized) && (_are_weights_reshaped);
 
     unsigned int mat_weights_cols = weights->info()->dimension(3);
     unsigned int mat_weights_rows = weights->info()->dimension(0) * weights->info()->dimension(1) * weights->info()->dimension(2) + bias_element;
@@ -205,8 +221,9 @@
     }
     else
     {
-        // _weights_reshaped will be auto configured in the kernel
-        _reshape_weights.configure(weights, biases_to_use, &_weights_reshaped, _is_interleaved_transposed /* 1xW transpose */);
+        // _weights_reshaped will be auto configured in the kernel.
+        // Only append the biases; do not transpose 1xW, as the weights will be reshaped inside CLGEMM
+        _reshape_weights.configure(weights, biases_to_use, &_weights_reshaped, false);
 
         weights = &_weights_reshaped;
     }
@@ -221,11 +238,11 @@
     // FIXME: input->clone() doesn't work with subtensors for grouped convolutions.
     TensorInfo im2col_reshaped_info(shape_im2col, 1, dt, input->info()->fixed_point_position());
     im2col_reshaped_info.set_quantization_info(input->info()->quantization_info());
-    _input_im2col_reshaped.allocator()->init(im2col_reshaped_info);
-    _memory_group.manage(&_input_im2col_reshaped);
+    _im2col_output.allocator()->init(im2col_reshaped_info);
+    _memory_group.manage(&_im2col_output);
 
     // Create GEMM output tensor
-    TensorShape shape_gemm = _input_im2col_reshaped.info()->tensor_shape();
+    TensorShape shape_gemm = _im2col_output.info()->tensor_shape();
     shape_gemm.set(0, mat_weights_cols);
     shape_gemm.set(1, mat_input_rows);
     const DataType gemm_data_type = _is_quantized ? DataType::S32 : dt;
@@ -237,24 +254,24 @@
     _memory_group.manage(&_gemm_output);
 
     // Configure im2col
-    _input_im2col_kernel.configure(input, &_input_im2col_reshaped, Size2D(kernel_width, kernel_height), conv_info, append_bias);
+    _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, append_bias);
 
     // Configure matrix multiply
     if(_is_interleaved_transposed)
     {
-        // Configure GEMMInterleave4x4. _input_interleaved_reshaped will be auto configured in the kernel
+        // Configure GEMMInterleave4x4. _interleave_output will be auto configured in the kernel
-        _input_interleave_kernel.configure(&_input_im2col_reshaped, &_input_interleaved_reshaped);
-        _memory_group.manage(&_input_interleaved_reshaped);
+        _interleave_kernel.configure(&_im2col_output, &_interleave_output);
+        _memory_group.manage(&_interleave_output);
 
         // Configure GEMM
-        configure_mm(&_input_interleaved_reshaped, weights, &_gemm_output);
-        _input_interleaved_reshaped.allocator()->allocate();
+        configure_mm(&_interleave_output, weights, &_gemm_output, true, _are_weights_reshaped);
+        _interleave_output.allocator()->allocate();
     }
     else
     {
-        configure_mm(&_input_im2col_reshaped, weights, &_gemm_output, false);
+        configure_mm(&_im2col_output, weights, &_gemm_output, false, _are_weights_reshaped);
     }
-    _input_im2col_reshaped.allocator()->allocate();
+    _im2col_output.allocator()->allocate();
 
     // Configure output stage for quantized case
     if(_is_quantized)
@@ -267,7 +284,7 @@
     }
 
     // Configure Col2Im
-    _output_col2im_kernel.configure(_is_quantized ? &_tmp_output : &_gemm_output, output, std::make_pair(conv_w, conv_h));
+    _col2im_kernel.configure(_is_quantized ? &_tmp_output : &_gemm_output, output, std::make_pair(conv_w, conv_h));
     if(_is_quantized)
     {
         _tmp_output.allocator()->allocate();
@@ -298,32 +315,39 @@
     _memory_group.acquire();
 
     // Run im2col
-    CLScheduler::get().enqueue(_input_im2col_kernel);
+    CLScheduler::get().enqueue(_im2col_kernel);
 
+    // Note: _is_interleaved_transposed is true only if the weights were passed to the function already reshaped
+    //       and the data type is not QASYMM8. If this flag is true, we run the GEMM kernel
+    //       instead of the GEMM function
     if(_is_interleaved_transposed)
     {
-        // Run interleave4x4
-        CLScheduler::get().enqueue(_input_interleave_kernel);
-    }
+        // Run interleave4x4 kernel
+        CLScheduler::get().enqueue(_interleave_kernel);
 
-    // Runs matrix multiply on reshaped matrices
-    if(_is_quantized)
-    {
-        _mm_gemmlowp.run();
+        // Run matrix multiply kernel
+        CLScheduler::get().enqueue(_mm_kernel);
     }
     else
     {
-        CLScheduler::get().enqueue(_mm_kernel);
-    }
+        // Run the CLGEMM or CLGEMMLowpMatrixMultiplyCore function
+        if(_is_quantized)
+        {
+            // Run gemmlowp
+            _mm_gemmlowp.run();
 
-    // Run output stage for quantized case
-    if(_is_quantized)
-    {
-        _gemmlowp_output_stage.run();
+            // Run output stage
+            _gemmlowp_output_stage.run();
+        }
+        else
+        {
+            // Run gemm
+            _mm_gemm.run();
+        }
     }
 
     // Reshape output matrix
-    CLScheduler::get().enqueue(_output_col2im_kernel, false);
+    CLScheduler::get().enqueue(_col2im_kernel, false);
 
     _memory_group.release();
 }
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 1657bdc..c676a10 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017, 2018 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -80,7 +80,7 @@
         matrix_a = &_tmp_a;
         matrix_b = &_tmp_b;
 
-        // _tmp_a and _tmp_n will be auto configured in _interleave_kernel and in _transpose_kernel
+        // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
 
         // Configure interleave kernel
         _interleave_kernel.configure(a, &_tmp_a);
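
For context, in the non-quantized, non-reshaped path the convolution layer now hands the
im2col output and the reshaped weights straight to CLGEMM with GEMMInfo(false, false, true),
i.e. neither operand is pre-reshaped and the weights are reshaped only on the first run.
A minimal standalone sketch of that call pattern follows (shapes and names are illustrative,
e.g. a 3x3x16 kernel with 32 filters over a 32x32 input):

    #include "arm_compute/runtime/CL/CLScheduler.h"
    #include "arm_compute/runtime/CL/CLTensor.h"
    #include "arm_compute/runtime/CL/functions/CLGEMM.h"

    using namespace arm_compute;

    void gemm_example()
    {
        CLScheduler::get().default_init();

        // a: im2col output, shape (K, M); b: reshaped weights, shape (N, K); dst: shape (N, M)
        CLTensor a, b, dst;
        a.allocator()->init(TensorInfo(TensorShape(144U, 1024U), 1, DataType::F32));
        b.allocator()->init(TensorInfo(TensorShape(32U, 144U), 1, DataType::F32));
        dst.allocator()->init(TensorInfo(TensorShape(32U, 1024U), 1, DataType::F32));

        CLGEMM gemm;
        // GEMMInfo(is_a_reshaped, is_b_reshaped, reshape_b_only_on_first_run):
        // nothing is pre-reshaped; the weights (b) are reshaped once and then reused.
        gemm.configure(&a, &b, nullptr, &dst, 1.0f, 0.0f, GEMMInfo(false, false, true));

        a.allocator()->allocate();
        b.allocator()->allocate();
        dst.allocator()->allocate();

        gemm.run();
        CLScheduler::get().sync();
    }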
diff --git a/tests/validation/CL/ConvolutionLayer.cpp b/tests/validation/CL/ConvolutionLayer.cpp
index 56e10f0..8b9db91 100644
--- a/tests/validation/CL/ConvolutionLayer.cpp
+++ b/tests/validation/CL/ConvolutionLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -114,7 +114,7 @@
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
 FIXTURE_DATA_TEST_CASE(RunSmall, CLConvolutionLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallConvolutionLayerDataset(),
-                                                                                                                     framework::dataset::make("ReshapeWeights", { true, false })),
+                                                                                                                     framework::dataset::make("ReshapeWeights", { true })),
                                                                                                              framework::dataset::make("DataType",
                                                                                                                      DataType::F16)))
 {
@@ -122,7 +122,7 @@
     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, CLConvolutionLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeConvolutionLayerDataset(),
-                                                                                                                   framework::dataset::make("ReshapeWeights", { true, false })),
+                                                                                                                   framework::dataset::make("ReshapeWeights", { true })),
                                                                                                            framework::dataset::make("DataType",
                                                                                                                    DataType::F16)))
 {
@@ -133,7 +133,7 @@
 
 TEST_SUITE(FP32)
 FIXTURE_DATA_TEST_CASE(RunSmall, CLConvolutionLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallConvolutionLayerDataset(),
-                                                                                                                      framework::dataset::make("ReshapeWeights", { true, false })),
+                                                                                                                      framework::dataset::make("ReshapeWeights", { true })),
                                                                                                               framework::dataset::make("DataType",
                                                                                                                       DataType::F32)))
 {
@@ -141,7 +141,7 @@
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, CLConvolutionLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeConvolutionLayerDataset(),
-                                                                                                                    framework::dataset::make("ReshapeWeights", { true, false })),
+                                                                                                                    framework::dataset::make("ReshapeWeights", { true })),
                                                                                                             framework::dataset::make("DataType",
                                                                                                                     DataType::F32)))
 {
@@ -158,7 +158,7 @@
 TEST_SUITE(QS8)
 // We test for fixed point precision [4,6]
 FIXTURE_DATA_TEST_CASE(RunSmall, CLConvolutionLayerFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                       framework::dataset::make("ReshapeWeights", { true, false })),
+                       framework::dataset::make("ReshapeWeights", { true })),
                        framework::dataset::make("DataType",
                                                 DataType::QS8)),
                        framework::dataset::make("FractionalBits", 4, 7)))
@@ -167,7 +167,7 @@
     validate(CLAccessor(_target), _reference, tolerance_fixed);
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, CLConvolutionLayerFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeConvolutionLayerDataset(),
-                                                                                                                       framework::dataset::make("ReshapeWeights", { true, false })),
+                                                                                                                       framework::dataset::make("ReshapeWeights", { true })),
                                                                                                                        framework::dataset::make("DataType",
                                                                                                                                DataType::QS8)),
                                                                                                                        framework::dataset::make("FractionalBits", 4, 7)))
@@ -180,7 +180,7 @@
 TEST_SUITE(QS16)
 // Testing for fixed point position [1,14)
 FIXTURE_DATA_TEST_CASE(RunSmall, CLConvolutionLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                       framework::dataset::make("ReshapeWeights", { true, false })),
+                       framework::dataset::make("ReshapeWeights", { true })),
                        framework::dataset::make("DataType",
                                                 DataType::QS16)),
                        framework::dataset::make("FractionalBits", 1, 14)))
@@ -189,7 +189,7 @@
     validate(CLAccessor(_target), _reference, tolerance_fixed);
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, CLConvolutionLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeConvolutionLayerDataset(),
-                                                                                                                        framework::dataset::make("ReshapeWeights", { true, false })),
+                                                                                                                        framework::dataset::make("ReshapeWeights", { true })),
                                                                                                                         framework::dataset::make("DataType",
                                                                                                                                 DataType::QS16)),
                                                                                                                         framework::dataset::make("FractionalBits", 1, 14)))