COMPMID-1276 - Allow GEMM to work with 3D input tensor

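Skip im2col in CLGEMMConvolutionLayer for non-quantized 1x1 convolutions with NHWC data layout; in that case the GEMM reinterprets its input as 3D and the bias is added after the GEMM with CLArithmeticAdditionKernel.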

Change-Id: I894e6b952ed8605e8f3ffc0ffc25c24730d4664c
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141909
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 1f4df4f..1d1b17b 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -102,7 +102,8 @@
     // Arguments used by GEMMReshapeInfo
     // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
     // in order to know how the matrices have been reshaped
-    const int m                         = a->info()->dimension(1);
+    bool      reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
+    const int m                         = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
     const int n                         = b->info()->dimension(0);
     const int k                         = a->info()->dimension(0);
     const int depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
@@ -118,6 +119,12 @@
     // Check if we need to reshape the matrix A and matrix B
     _is_interleaved_transposed = is_interleaved_transposed(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target);
 
+    // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
+    if(_is_interleaved_transposed)
+    {
+        reinterpret_input_as_3d = false;
+    }
+
     if(_is_interleaved_transposed)
     {
         matrix_a = &_tmp_a;
@@ -132,14 +139,15 @@
         // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
 
         // Configure interleave kernel
-        _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height);
+        _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d());
 
         // Configure transpose kernel
         _transpose_kernel.configure(b, &_tmp_b, mult_transpose1xW_width);
     }
 
     // Configure and tune matrix multiply kernel
-    _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d));
+    _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d,
+                                                                                                        reinterpret_input_as_3d));
     CLScheduler::get().tune_kernel_static(_mm_kernel);
 
     if(_is_interleaved_transposed)
@@ -180,11 +188,13 @@
     // Arguments used by GEMMReshapeInfo
     // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
     // in order to know how the matrices have been reshaped
-    const int m                         = a->dimension(1);
+    bool      reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
+    const int m                         = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
     const int n                         = b->dimension(0);
     const int k                         = a->dimension(0);
     int       mult_transpose1xW_width   = 1;
     int       mult_interleave4x4_height = 1;
+    const int depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
 
     if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
     {
@@ -192,19 +202,25 @@
         mult_interleave4x4_height = 2;
     }
 
-    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());
-
     // Check if we need to reshape the matrix A and matrix B
     const bool run_interleave_transpose = is_interleaved_transposed(m, n, k, a->data_type(), reshape_b_only_on_first_run, gpu_target);
 
+    // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
+    if(run_interleave_transpose)
+    {
+        reinterpret_input_as_3d = false;
+    }
+
+    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, reinterpret_input_as_3d);
+
     if(run_interleave_transpose)
     {
         matrix_a_info = &tmp_a_info;
         matrix_b_info = &tmp_b_info;
 
         // Validate interleave kernel
-        auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height)));
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height));
+        auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()));
 
         // Validate transpose kernel
         auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index f1d2924..de62829 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -91,15 +91,15 @@
 
 CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
     : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _col2im_kernel(), _activationlayer_function(),
-      _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _skip_im2col(false), _is_quantized(false),
-      _is_activationlayer_enabled(false), _is_prepared(false)
+      _add_bias_kernel(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _append_bias(false), _skip_im2col(false),
+      _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
 {
 }
 
 void CLGEMMConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, int gemm_3d_depth)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights);
-    ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), output->info()));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), output->info(), _skip_im2col));
 
     if(_is_quantized)
     {
@@ -120,15 +120,16 @@
     else
     {
         // Configure matrix multiply function
-        _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, gemm_3d_depth));
+        _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, gemm_3d_depth,
+                                                                                 _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */));
     }
 }
 
-Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth)
+Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth, bool skip_im2col)
 {
     const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
 
-    const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth);
+    const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
     if(is_quantized)
     {
         // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
@@ -180,7 +181,8 @@
     _original_weights = weights;
     _is_quantized     = is_data_type_quantized_asymmetric(input->info()->data_type());
     _data_layout      = data_layout;
-    _skip_im2col      = false;
+    _skip_im2col      = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1) && !_is_quantized;
+    _append_bias      = (biases != nullptr) && (!_is_quantized);
 
     // Set the GPU target for im2col and col2im
     _im2col_kernel.set_target(CLScheduler::get().target());
@@ -191,9 +193,8 @@
     ICLTensor       *gemm_output_to_use        = output;
     ICLTensor       *gemm_output_staged_to_use = output;
 
-    const bool       append_bias   = (biases != nullptr) && (!_is_quantized);
-    const unsigned   bias_element  = (append_bias) ? 1 : 0;
-    const ICLTensor *biases_to_use = (append_bias) ? biases : nullptr;
+    const unsigned   bias_element  = (_append_bias && !_skip_im2col) ? 1 : 0;
+    const ICLTensor *biases_to_use = (_append_bias && !_skip_im2col) ? biases : nullptr;
 
     // Get parameters from conv_info
     unsigned int stride_x = 0;
@@ -238,12 +239,17 @@
         _memory_group.manage(&_im2col_output);
 
         // Configure and tune im2col
-        _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation);
+        _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, dilation);
         CLScheduler::get().tune_kernel_static(_im2col_kernel);
 
         // Update GEMM input
         gemm_input_to_use = &_im2col_output;
     }
+    else if(_append_bias)
+    {
+        // Configure add bias kernel
+        _add_bias_kernel.configure(output, biases, output, ConvertPolicy::SATURATE);
+    }
 
     // Create GEMM output tensor
     if(!is_nhwc || _is_quantized)
@@ -281,28 +287,23 @@
         float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale;
         int   output_multiplier, output_shift;
         quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
-        if(!is_nhwc)
-        {
-            _memory_group.manage(&_tmp_output);
-            gemm_output_staged_to_use = &_tmp_output;
-        }
+
+        _memory_group.manage(&_tmp_output);
+        gemm_output_staged_to_use = &_tmp_output;
+
         _gemmlowp_output_stage.configure(gemm_output_to_use, biases, gemm_output_staged_to_use, output_multiplier, output_shift, output_quant_info.offset);
     }
 
-    if(!is_nhwc)
+    if(!is_nhwc || _is_quantized)
     {
         // Configure and tune Col2Im
         _col2im_kernel.configure(_is_quantized ? gemm_output_staged_to_use : gemm_output_to_use, output, std::make_pair(conv_w, conv_h));
         CLScheduler::get().tune_kernel_static(_col2im_kernel);
     }
 
-    if(_is_quantized && !is_nhwc)
-    {
-        _tmp_output.allocator()->allocate();
-    }
-
     if(!is_nhwc || _is_quantized)
     {
+        _tmp_output.allocator()->allocate();
         _gemm_output.allocator()->allocate();
     }
 
@@ -348,10 +349,10 @@
     const ITensorInfo *weights_to_use            = weights;
 
     const bool     is_nhwc      = data_layout == DataLayout::NHWC;
-    const bool     skip_im2col  = false;
     const bool     is_quantized = is_data_type_quantized_asymmetric(data_type);
+    const bool     skip_im2col  = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1) && !is_quantized;
     const bool     append_bias  = (biases != nullptr) && (!is_quantized);
-    const unsigned bias_element = (append_bias) ? 1 : 0;
+    const unsigned bias_element = (append_bias && !skip_im2col) ? 1 : 0;
 
     ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != input->dimension(idx_channel));
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -410,6 +411,11 @@
         ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation));
         gemm_input_to_use = &im2col_reshaped_info;
     }
+    else if(append_bias)
+    {
+        // Validate add bias kernel
+        ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAdditionKernel::validate(output, biases, output, ConvertPolicy::SATURATE));
+    }
 
     // Create GEMM output tensor
     if(!is_nhwc || is_quantized)
@@ -424,25 +430,24 @@
         gemm_output_to_use = &info_gemm;
     }
 
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, gemm_output_to_use, (data_layout == DataLayout::NHWC) ? conv_h : 1));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, gemm_output_to_use, (data_layout == DataLayout::NHWC) ? conv_h : 1, skip_im2col));
 
     if(is_quantized)
     {
         float multiplier = input->quantization_info().scale * weights_to_use->quantization_info().scale / output->quantization_info().scale;
         int   output_multiplier, output_shift;
         quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
-        if(!is_nhwc)
-        {
-            tmp_info = TensorInfo(gemm_output_to_use->tensor_shape(), 1, DataType::QASYMM8);
-            tmp_info.set_quantization_info(output->quantization_info());
-            gemm_output_staged_to_use = &tmp_info;
-        }
+
+        tmp_info = TensorInfo(gemm_output_to_use->tensor_shape(), 1, DataType::QASYMM8);
+        tmp_info.set_quantization_info(output->quantization_info());
+        gemm_output_staged_to_use = &tmp_info;
+
         // Validate output stage for quantized case
         CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(gemm_output_to_use, biases, gemm_output_staged_to_use, output->quantization_info().offset);
     }
 
     // Validate Col2Im
-    if(!is_nhwc)
+    if(!is_nhwc || is_quantized)
     {
         ARM_COMPUTE_RETURN_ON_ERROR(CLCol2ImKernel::validate(is_quantized ? gemm_output_staged_to_use : gemm_output_to_use,
                                                              output,
@@ -485,8 +490,13 @@
         _mm_gemm.run();
     }
 
+    if(_skip_im2col && _append_bias)
+    {
+        CLScheduler::get().enqueue(_add_bias_kernel);
+    }
+
     // Reshape output matrix
-    if(_data_layout == DataLayout::NCHW)
+    if(_data_layout == DataLayout::NCHW || _is_quantized)
     {
         CLScheduler::get().enqueue(_col2im_kernel, false);
     }
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 842ee73..c2e18a7 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -205,16 +205,17 @@
     const int             k                         = a->dimension(0);
     constexpr int         mult_transpose1xW_width   = 1;
     constexpr int         mult_interleave4x4_height = 1;
-    const GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height);
+    const int             depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
+    const GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d);
 
     bool reshape_matrices = is_interleaved_transposed(m, n, k, gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target());
 
     if(reshape_matrices)
     {
-        TensorInfo info_a(compute_interleaved_shape(*a, mult_interleave4x4_height), 1, a->data_type());
+        TensorInfo info_a(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()), 1, a->data_type());
         TensorInfo info_b(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width), 1, b->data_type());
 
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &info_a, mult_interleave4x4_height));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &info_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()));
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMTranspose1xWKernel::validate(b, &info_b, mult_transpose1xW_width));
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(&info_a, &info_b, output, reshape_matrices, reshape_info));
     }