Optimize Transposed Convolution for CL backend (Quantized)

This patch optimizes transposed convolution for the QASYMM8 and QASYMM8_SIGNED data types by extending the transposed convolution kernel written for FP32/FP16.

Resolves: COMPMID-5723
Change-Id: Iab8f09231938adb949c506fd915ed45b885e5c7c
Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8792
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/nhwc/transposed_convolution.cl b/src/core/CL/cl_kernels/nhwc/transposed_convolution.cl
index c01a44f..1ca282c 100644
--- a/src/core/CL/cl_kernels/nhwc/transposed_convolution.cl
+++ b/src/core/CL/cl_kernels/nhwc/transposed_convolution.cl
@@ -29,7 +29,7 @@
 /** OpenCL kernel to compute the transposed convolution.
  *
  * @note Data layout supported: NHWC
- * @note Data type supported: F32/F16
+ * @note Data type supported: F32/F16/QASYMM8/QASYMM8_SIGNED
  * @note The transposed convolution padding (left and top) must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=2, -DPAD_TOP=2)
  * @note The transposed convolution strides must be passed at compile time using -DSTRIDE_X and -DSTRIDE_Y (e.g. -DSTRIDE_X=2, -DSTRIDE_Y=2)
  * @note The spatial dimensions of the weights must be passed at compile time using -DWEI_WIDTH and -DWEI_HEIGHT (e.g. -DWEI_WIDTH=9, -DWEI_HEIGHT=9)
@@ -43,15 +43,26 @@
  * @note The data type of the source tensor must be passed at compile time using -DSRC_DATA_TYPE (e.g. -DSRC_DATA_TYPE=float)
  * @note The data type of the weights tensor must be passed at compile time using -DWEI_DATA_TYPE (e.g. -DWEI_DATA_TYPE=float)
  * @note The data type of the destination tensor must be passed at compile time using -DDST_DATA_TYPE (e.g. -DDST_DATA_TYPE=float)
+ * @note The data type of the bias tensor must be passed at compile time using -DBIA_DATA_TYPE (e.g. -DBIA_DATA_TYPE=float)
  * @note The data type of the accumulators must be passed at compile time using -DACC_DATA_TYPE (e.g. -DACC_DATA_TYPE=float)
  * @note The number of M0 rows (width*height) to process must be passed at compile time using -DM0 (e.g. -DM0=2)
  * @note The number of N0 output channels to process must be passed at compile time using -DN0 (e.g. -DN0=2)
  * @note The number of K0 inner accumulations must be passed at compile time using -DK0 (e.g. -DK0=2)
  * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_N0 (e.g. -DPARTIAL_N0=1)
+ * @note If bias exists, the compile time argument -DHAS_BIAS should be passed
  * @note Only the following configurations of M0, N0 and K0 are currently supported:
  *  - M0 = 1
  *  - N0 = 1
- *  - K0 = 2, 3, 4, 8
+ *  - K0 = 2, 3, 4, 8, 16
+ *
+ * @note In case of QASYMM8/QASYMM8_SIGNED, the following extra information must be passed at compile time:
+ * - -DIS_QUANTIZED
+ * - The destination quantization multiplier e.g. -DDST_MULTIPLIER=1234
+ * - The destination quantization shift e.g. -DDST_SHIFT=4
+ * - The destination offset e.g. -DDST_OFFSET=4
+ * - The source offset e.g. -DSRC_OFFSET=4
+ * - The weights offset e.g. -DWEI_OFFSET=4
+ * - The quantized zero value e.g. -DZERO_VALUE=4
  *
  *
  * @param[in]  src_ptr                           Pointer to the source tensor. Supported data type: F16/F32
@@ -108,6 +119,12 @@
 #define _IDST_CHANNELS DST_CHANNELS
 #define _IY_MULTIPLIER (_IWEI_WIDTH * _IWEI_HEIGHT)
 
+#if defined(IS_QUANTIZED)
+#define _IOUTPUT_TILE cq
+#else // defined(IS_QUANTIZED)
+#define _IOUTPUT_TILE c
+#endif // defined(IS_QUANTIZED)
+
     const int cout = GET_SPATIAL_IDX(0, N0, PARTIAL_N0); // OFM
     const int mout = GET_SPATIAL_IDX(1, M0, 0);          // WIDTH x HEIGHT
     const int bout = GET_SPATIAL_IDX(2, 1, 0);           // BATCH SIZE IDX
@@ -144,7 +161,7 @@
     {
         for(int xk = x_start, xi_step = 0; xk >= 0; xk -= STRIDE_X, ++xi_step)
         {
-            int weights_y = cout * _IY_MULTIPLIER + yk * _IWEI_WIDTH + xk;
+            const int weights_y = cout * _IY_MULTIPLIER + yk * _IWEI_WIDTH + xk;
 
             TILE(int, 1, M0, my);
 
@@ -169,12 +186,12 @@
                 // Initialize tiles
                 LOOP_UNROLLING(int, i, 0, 1, M0,
                 {
-                    a[i].v = 0.f;
+                    a[i].v = ZERO_VALUE;
                 })
 
                 LOOP_UNROLLING(int, i, 0, 1, N0,
                 {
-                    b[i].v = 0.f;
+                    b[i].v = ZERO_VALUE;
                 })
 
                 // Load tile from the src tensor
@@ -185,6 +202,12 @@
 
                 // Compute the matrix multiplication between two tiles
                 T_MMUL(SRC_DATA_TYPE, WEI_DATA_TYPE, ACC_DATA_TYPE, M0, N0, K0, NT, T, a, b, c);
+
+#if defined(IS_QUANTIZED)
+                // Apply the offset correction (correction usually needed for asymmetric quantized computation)
+                // The computation is not performed if both SRC_OFFSET and WEI_OFFSET are zero
+                T_OFFSET_CORRECTION(ACC_DATA_TYPE, M0, N0, K0, SRC_OFFSET, WEI_OFFSET, a, b, c);
+#endif // defined(IS_QUANTIZED)
             }
 
             // This #if directive should be removed in case of dynamic tensor support
@@ -198,7 +221,7 @@
                 // Initialize tiles
                 LOOP_UNROLLING(int, i, 0, 1, M0,
                 {
-                    a[i].v = 0.f;
+                    a[i].v = ZERO_VALUE;
                 })
 
                 // Load tile from the src tensor
@@ -211,11 +234,23 @@
 
                 // Compute the matrix multiplication between two tiles
                 T_MMUL(SRC_DATA_TYPE, WEI_DATA_TYPE, ACC_DATA_TYPE, M0, N0, 1, NT, T, a, b, c);
+
+#if defined(IS_QUANTIZED)
+                // Apply the offset correction (correction usually needed for asymmetric quantized computation)
+                // The computation is not performed if both SRC_OFFSET and WEI_OFFSET are zero
+                T_OFFSET_CORRECTION(ACC_DATA_TYPE, M0, N0, 1, SRC_OFFSET, WEI_OFFSET, a, b, c);
+#endif // defined(IS_QUANTIZED)
             }
 #endif // defined(LEFTOVER_LOOP)
         }
     }
 
+#if defined(IS_QUANTIZED)
+    const int total_pixels = floor((1 + y_start / (float)STRIDE_Y)) * floor(1 + x_start / (float)STRIDE_X);
+
+    T_ADD_CONSTANT(ACC_DATA_TYPE, M0, N0, c, (total_pixels * _ISRC_CHANNELS * SRC_OFFSET * WEI_OFFSET), c);
+#endif // defined(IS_QUANTIZED)
+
 #if defined(HAS_BIAS)
     TILE(BIA_DATA_TYPE, 1, N0, bias0);
 
@@ -226,6 +261,14 @@
 
 #endif // HAS_BIAS
 
+#if defined(IS_QUANTIZED)
+
+    TILE(DST_DATA_TYPE, M0, N0, cq);
+
+    // Quantize the tile
+    T_QUANTIZE8_ASYMMETRIC(ACC_DATA_TYPE, DST_DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, c, cq);
+#endif // defined(IS_QUANTIZED)
+
     TILE(uint, M0, 1, dst_indirect_y);
 
     // Calculate the destination indirect Y
@@ -238,7 +281,7 @@
     bool x_cond = PARTIAL_N0 != 0 && get_global_id(0) == 0;
 
     // Store the tile in reverse order so the invalid values are overwritten with the valid ones
-    T_STORE_INDIRECT_WIDTH_SELECT(DST_DATA_TYPE, M0, N0, PARTIAL_N0, DST_TENSOR_TYPE, dst, cout, dst_stride_y, x_cond, c, dst_indirect_y);
+    T_STORE_INDIRECT_WIDTH_SELECT(DST_DATA_TYPE, M0, N0, PARTIAL_N0, DST_TENSOR_TYPE, dst, cout, dst_stride_y, x_cond, _IOUTPUT_TILE, dst_indirect_y);
 
 #undef _IWEI_WIDTH
 #undef _IWEI_HEIGHT
diff --git a/src/gpu/cl/kernels/ClTransposedConvolutionKernel.cpp b/src/gpu/cl/kernels/ClTransposedConvolutionKernel.cpp
index 16c6ad9..714ca8e 100644
--- a/src/gpu/cl/kernels/ClTransposedConvolutionKernel.cpp
+++ b/src/gpu/cl/kernels/ClTransposedConvolutionKernel.cpp
@@ -30,6 +30,8 @@
 #include "src/core/helpers/WindowHelpers.h"
 #include "support/Cast.h"
 
+#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
+
 namespace arm_compute
 {
 namespace opencl
@@ -42,7 +44,7 @@
                           const PadStrideInfo &deconv_info)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32, DataType::QASYMM8_SIGNED, DataType::QASYMM8);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(input, DataLayout::NHWC);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(weights, DataLayout::NHWC);
@@ -57,7 +59,15 @@
 
     if(biases != nullptr)
     {
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
+        if(is_data_type_quantized_asymmetric(input->data_type()))
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
+        }
+        else
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
+        }
+
         ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->dimension(channel_idx) != weights->dimension(batch_idx),
                                         "Biases size and number of dst feature maps should match");
         ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->num_dimensions() > 1, "Biases should be one dimensional");
@@ -127,12 +137,12 @@
     const std::string kernel_name = "transposed_convolution_nhwc";
     CLBuildOptions    build_options;
 
-    const DataType input_data_type = input->data_type();        // Fp32 or Fp16 only
-    const auto     strides         = deconv_info.stride();
+    const DataType    input_data_type = input->data_type();
+    const PaddingInfo strides         = deconv_info.stride();
 
     const unsigned int n0               = 1;
     const unsigned int m0               = 1;
-    const unsigned int k0               = adjust_vec_size(input_data_type == DataType::F32 ? 4 : 8, input_channels);
+    const unsigned int k0               = adjust_vec_size(16 / input->element_size(), input_channels);
     const unsigned int partial_store_n0 = output_channels % n0;
 
     if(biases != nullptr)
@@ -167,7 +177,36 @@
     build_options.add_option("-DK0=" + support::cpp11::to_string(k0));
     build_options.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0));
     build_options.add_option_if((input_channels % k0) != 0, "-DLEFTOVER_LOOP");
-    build_options.add_option("-DACC_DATA_TYPE=" + get_cl_type_from_data_type(input_data_type));
+
+    if(is_data_type_quantized(output_data_type))
+    {
+        const UniformQuantizationInfo iqinfo = input->quantization_info().uniform();
+        const UniformQuantizationInfo wqinfo = weights->quantization_info().uniform();
+        const UniformQuantizationInfo oqinfo = output->quantization_info().uniform();
+
+        PixelValue zero_value = PixelValue(0, input->data_type(), input->quantization_info());
+        int        zero_value_s32;
+        zero_value.get(zero_value_s32);
+
+        float multiplier        = iqinfo.scale * wqinfo.scale / oqinfo.scale;
+        int   output_multiplier = 0;
+        int   output_shift      = 0;
+
+        quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
+        build_options.add_option("-DIS_QUANTIZED");
+        build_options.add_option("-DDST_MULTIPLIER=" + support::cpp11::to_string(output_multiplier));
+        build_options.add_option("-DDST_SHIFT=" + support::cpp11::to_string(output_shift));
+        build_options.add_option("-DSRC_OFFSET=" + support::cpp11::to_string(-iqinfo.offset));
+        build_options.add_option("-DWEI_OFFSET=" + support::cpp11::to_string(-wqinfo.offset));
+        build_options.add_option("-DDST_OFFSET=" + support::cpp11::to_string(oqinfo.offset));
+        build_options.add_option("-DZERO_VALUE=" + support::cpp11::to_string(zero_value_s32));
+        build_options.add_option("-DACC_DATA_TYPE=" + get_cl_type_from_data_type(DataType::S32));
+    }
+    else
+    {
+        build_options.add_option("-DACC_DATA_TYPE=" + get_cl_type_from_data_type(input_data_type));
+        build_options.add_option("-DZERO_VALUE=" + support::cpp11::to_string(0));
+    }
 
     if(compile_context.get_ddk_version() >= 30)
     {
diff --git a/src/gpu/cl/operators/ClTransposedConvolution.h b/src/gpu/cl/operators/ClTransposedConvolution.h
index bc04387..58ebc68 100644
--- a/src/gpu/cl/operators/ClTransposedConvolution.h
+++ b/src/gpu/cl/operators/ClTransposedConvolution.h
@@ -57,11 +57,11 @@
      *
      * @param[in]  compile_context The compile context to be used.
      * @param[in]  input           Input tensor info with dimensions [IFM, width, height, batch]
-     *                             Data types supported: F16/F32.
+     *                             Data types supported: F16/F32/QASYMM8/QASYMM8_SIGNED.
      * @param[in]  weights         Weight tensor info with dimensions [IFM, width, height, OFM].
      *                             Data type supported: Same as @p input
      * @param[in]  biases          (Optional) Biases tensor info. Biases are 1D tensor with dimension [OFM].
-     *                             Data type supported: Should match @p input data type
+     *                             Data type supported: Should match @p input data type if floating point, otherwise S32.
      * @param[out] output          Output tensor info with dimensions [OFM, width, height, batch]
      *                             The 1st dimension must be equal to the 4th dimension of the @p weights tensor.
      *                             Data types supported: Same as @p input.
diff --git a/src/runtime/CL/functions/CLDeconvolutionLayer.cpp b/src/runtime/CL/functions/CLDeconvolutionLayer.cpp
index a4db6d7..ea7f3e7 100644
--- a/src/runtime/CL/functions/CLDeconvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLDeconvolutionLayer.cpp
@@ -141,16 +141,16 @@
 {
     ARM_COMPUTE_UNUSED(output, bias, weights_info);
 
-    if(input->data_layout() == DataLayout::NHWC && (input->data_type() == DataType::F32 || input->data_type() == DataType::F16))
-    {
-        return DeconvolutionMethod::DIRECT;
-    }
-
     if(is_data_type_quantized_per_channel(weights->data_type()))
     {
         return DeconvolutionMethod::UPSCALE_CONV2D;
     }
 
+    if(input->data_layout() == DataLayout::NHWC)
+    {
+        return DeconvolutionMethod::DIRECT;
+    }
+
     const DataLayout data_layout = input->data_layout();
 
     const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);