COMPMID-474 - Add support for QS8/QS16 DirectConvolution CL

Change-Id: I537e4acbc02c8d880ff8630ea62223e0f1a1dda3
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/82875
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp
index 1c855e4..6bc82a2 100644
--- a/src/core/CL/CLHelpers.cpp
+++ b/src/core/CL/CLHelpers.cpp
@@ -86,6 +86,8 @@
             return "uint";
         case DataType::S32:
             return "int";
+        case DataType::QS32:
+            return "qs32";
         case DataType::U64:
             return "ulong";
         case DataType::S64:
@@ -134,6 +136,8 @@
             return "char";
         case DataType::QS16:
             return "short";
+        case DataType::QS32:
+            return "int";
         default:
             return get_cl_type_from_data_type(dt);
     }
diff --git a/src/core/CL/cl_kernels/direct_convolution1x1.cl b/src/core/CL/cl_kernels/direct_convolution1x1.cl
index ec0551b..66c618e 100644
--- a/src/core/CL/cl_kernels/direct_convolution1x1.cl
+++ b/src/core/CL/cl_kernels/direct_convolution1x1.cl
@@ -23,6 +23,23 @@
  */
 #include "helpers.h"
 
+#if defined(FIXED_POINT_POSITION)
+#include "fixed_point.h"
+
+#define ADD_OP(a, b) ADD_SAT_OP_EXPAND((a), (b), DATA_TYPE_PROMOTED, 8)
+#define MUL_OP(a, b) MUL_SAT_OP_EXPAND(CONVERT((a), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)), CONVERT((b), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)), DATA_TYPE_PROMOTED, 8, FIXED_POINT_POSITION)
+
+// There is no need to have a larger intermediate type for qs32 because all the arguments are already promoted
+MULQ_SAT_IMPL(qs32x8, qs32x8)
+
+#else /* FIXED_POINT_POSITION */
+
+#define ADD_OP(a, b) ((a) + (b))
+#define MUL_OP(a, b) ((a) * (b))
+#define CONVERT_SAT(a, b) ((a))
+
+#endif /* FIXED_POINT_POSITION */
+
 #if STRIDE_X == 3
 #define INPUT_PIXEL_STR(data_size) extract_input_stride3_##data_size
 #define INPUT_PIXEL(data_size) INPUT_PIXEL_STR(data_size)
@@ -165,7 +182,7 @@
     Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
 #endif /* defined(HAS_BIAS) */
 
-    VEC_DATA_TYPE(DATA_TYPE, 8)
+    VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)
     pixels = 0;
 
     const uint z_index = get_global_id(2);
@@ -177,15 +194,15 @@
         DATA_TYPE weight = *(__global DATA_TYPE *)weights.ptr;
         VEC_DATA_TYPE(DATA_TYPE, 8)
         input_pixel = INPUT_PIXEL(DATA_SIZE)((__global DATA_TYPE *)src.ptr);
-        pixels += weight * input_pixel;
+        pixels      = ADD_OP(pixels, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))weight, input_pixel));
         src.ptr += src_stride_z;
         weights.ptr += weights_stride_z;
     }
 
 #ifdef HAS_BIAS
-    pixels += (VEC_DATA_TYPE(DATA_TYPE, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, z_index)));
+    pixels = ADD_OP(pixels, (VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, z_index))));
 #endif /* defined(HAS_BIAS) */
 
-    vstore8(pixels, 0, (__global DATA_TYPE *)dst.ptr);
+    vstore8(CONVERT_SAT(pixels, VEC_DATA_TYPE(DATA_TYPE, 8)), 0, (__global DATA_TYPE *)dst.ptr);
 }
 #endif // defined(DATA_TYPE) && defined(DATA_SIZE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
\ No newline at end of file
diff --git a/src/core/CL/cl_kernels/direct_convolution3x3.cl b/src/core/CL/cl_kernels/direct_convolution3x3.cl
index 51886ef..4da7c39 100644
--- a/src/core/CL/cl_kernels/direct_convolution3x3.cl
+++ b/src/core/CL/cl_kernels/direct_convolution3x3.cl
@@ -23,6 +23,23 @@
  */
 #include "helpers.h"
 
+#if defined(FIXED_POINT_POSITION)
+#include "fixed_point.h"
+
+#define ADD_OP(a, b) ADD_SAT_OP_EXPAND((a), (b), DATA_TYPE_PROMOTED, 8)
+#define MUL_OP(a, b) MUL_SAT_OP_EXPAND(CONVERT((a), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)), CONVERT((b), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)), DATA_TYPE_PROMOTED, 8, FIXED_POINT_POSITION)
+
+// There is no need to have a larger intermediate type for qs32 because all the arguments are already promoted
+MULQ_SAT_IMPL(qs32x8, qs32x8)
+
+#else /* FIXED_POINT_POSITION */
+
+#define ADD_OP(a, b) ((a) + (b))
+#define MUL_OP(a, b) ((a) * (b))
+#define CONVERT_SAT(a, b) ((a))
+
+#endif /* FIXED_POINT_POSITION */
+
 #if STRIDE_X == 1
 #define CONVOLUTION1x3(acc, src_row_ptr, weights_row_ptr) CONVOLUTION1x3_STRIDE1(acc, src_row_ptr, weights_row_ptr)
 #elif STRIDE_X == 2 /* STRIDE_X == 1 */
@@ -31,31 +48,31 @@
 #error "STRIDE_X larger than 2 is not supported"
 #endif /* STRIDE_X == 2 */
 
-#define CONVOLUTION1x3_STRIDE1(acc, src_row_ptr, weights_row_ptr)                                                               \
-    ({                                                                                                                          \
-        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                                             \
-        weights_values0 = vload4(0, weights_row_ptr);                                                                           \
-        VEC_DATA_TYPE(DATA_TYPE, 8)                                                                                             \
-        src0 = vload8(0, src_row_ptr);                                                                                          \
-        VEC_DATA_TYPE(DATA_TYPE, 2)                                                                                             \
-        src1 = vload2(0, src_row_ptr + 8);                                                                                      \
+#define CONVOLUTION1x3_STRIDE1(acc, src_row_ptr, weights_row_ptr)                                                                                  \
+    ({                                                                                                                                             \
+        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                                                                \
+        weights_values0 = vload4(0, weights_row_ptr);                                                                                              \
+        VEC_DATA_TYPE(DATA_TYPE, 8)                                                                                                                \
+        src0 = vload8(0, src_row_ptr);                                                                                                             \
+        VEC_DATA_TYPE(DATA_TYPE, 2)                                                                                                                \
+        src1 = vload2(0, src_row_ptr + 8);                                                                                                         \
         \
-        acc += src0 * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s0;                                                          \
-        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1234, src0.s567, src1.s0) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s1; \
-        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s234, src0.s567, src1.s01) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2; \
+        acc = ADD_OP(acc, MUL_OP(src0, (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s0));                                                          \
+        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1234, src0.s567, src1.s0), (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s1)); \
+        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s234, src0.s567, src1.s01), (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2)); \
     })
 
-#define CONVOLUTION1x3_STRIDE2(acc, src_row_ptr, weights_row_ptr)                                                            \
-    ({                                                                                                                       \
-        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                                          \
-        weights_values0 = vload4(0, weights_row_ptr);                                                                        \
-        VEC_DATA_TYPE(DATA_TYPE, 16)                                                                                         \
-        src0           = vload16(0, src_row_ptr);                                                                            \
-        DATA_TYPE src1 = *(src_row_ptr + 16);                                                                                \
+#define CONVOLUTION1x3_STRIDE2(acc, src_row_ptr, weights_row_ptr)                                                                               \
+    ({                                                                                                                                          \
+        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                                                             \
+        weights_values0 = vload4(0, weights_row_ptr);                                                                                           \
+        VEC_DATA_TYPE(DATA_TYPE, 16)                                                                                                            \
+        src0           = vload16(0, src_row_ptr);                                                                                               \
+        DATA_TYPE src1 = *(src_row_ptr + 16);                                                                                                   \
         \
-        acc += src0.even * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s0;                                                  \
-        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1357, src0.s9BDF) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s1;      \
-        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s2468, src0.sACE, src1) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2; \
+        acc = ADD_OP(acc, MUL_OP(src0.even, (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s0));                                                  \
+        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1357, src0.s9BDF), (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s1));      \
+        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s2468, src0.sACE, src1), (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2)); \
     })
 
 /** This kernel performs a direct convolution to convolve the low three dimensions.
@@ -108,7 +125,7 @@
     Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
     Tensor3D dst     = CONVERT_TO_TENSOR3D_STRUCT(dst);
 
-    VEC_DATA_TYPE(DATA_TYPE, 8)
+    VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)
     pixels0 = 0;
 
     __global uchar *weights_addr = (__global uchar *)tensor3D_offset(&weights, 0, 0, 0);
@@ -130,9 +147,9 @@
 #ifdef HAS_BIAS
     Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
 
-    pixels0 += (VEC_DATA_TYPE(DATA_TYPE, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, kernel_index)));
+    pixels0 = ADD_OP(pixels0, (VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, kernel_index))));
 #endif /* defined(HAS_BIAS) */
 
-    vstore8(pixels0, 0, (__global DATA_TYPE *)dst.ptr);
+    vstore8(CONVERT_SAT(pixels0, VEC_DATA_TYPE(DATA_TYPE, 8)), 0, (__global DATA_TYPE *)dst.ptr);
 }
 #endif // defined(DATA_TYPE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
\ No newline at end of file
diff --git a/src/core/CL/cl_kernels/fixed_point.h b/src/core/CL/cl_kernels/fixed_point.h
index 7038d40..d35a46f 100644
--- a/src/core/CL/cl_kernels/fixed_point.h
+++ b/src/core/CL/cl_kernels/fixed_point.h
@@ -168,6 +168,11 @@
 ADDQ_SAT_IMPL(qs16x4)
 ADDQ_SAT_IMPL(qs16x8)
 ADDQ_SAT_IMPL(qs16x16)
+ADDQ_SAT_IMPL(qs32x1)
+ADDQ_SAT_IMPL(qs32x2)
+ADDQ_SAT_IMPL(qs32x4)
+ADDQ_SAT_IMPL(qs32x8)
+ADDQ_SAT_IMPL(qs32x16)
 
 #define ADD_SAT_OP_EXPAND_STR(a, b, type, size) add_sat_##type##x##size((a), (b))
 #define ADD_SAT_OP_EXPAND(a, b, type, size) ADD_SAT_OP_EXPAND_STR(a, b, type, size)
@@ -213,6 +218,8 @@
         return CONVERT((res >> (itype)fixed_point_position), type);                    \
     }
 
+MULQ_IMPL(qs8x8, qs16x8)
+MULQ_IMPL(qs16x8, qs32x8)
 MULQ_IMPL(qs8x16, qs16x16)
 MULQ_IMPL(qs16x16, qs32x16)
 
@@ -234,8 +241,9 @@
         return CONVERT_SAT((res >> (itype)fixed_point_position), type);                       \
     }
 
-MULQ_SAT_IMPL(qs8x16, qs16x16)
+MULQ_SAT_IMPL(qs8x8, qs16x8)
 MULQ_SAT_IMPL(qs16x8, qs32x8)
+MULQ_SAT_IMPL(qs8x16, qs16x16)
 MULQ_SAT_IMPL(qs16x16, qs32x16)
 
 #define MUL_SAT_OP_EXPAND_STR(a, b, type, size, position) mul_sat_##type##x##size((a), (b), (position))
diff --git a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
index 5f14d16..c5fdb77 100644
--- a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
+++ b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
@@ -50,7 +50,7 @@
 
 void CLDirectConvolutionLayerKernel::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
     ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) != weights->info()->dimension(1),
                              "Only kernel sizes 1x1 and 3x3 are supported");
@@ -102,12 +102,32 @@
     std::stringstream     kernel_name;
     std::set<std::string> options;
     kernel_name << "direct_convolution" << kernel_size << "x" << kernel_size;
+    DataType promoted_type = input->info()->data_type();
 
     options.emplace("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
     options.emplace("-DDATA_SIZE=" + get_data_size_from_data_type(input->info()->data_type()));
     options.emplace("-DWEIGHTS_DEPTH=" + support::cpp11::to_string(_weights->info()->dimension(2)));
     options.emplace("-DSTRIDE_X=" + support::cpp11::to_string(_conv_stride_x));
 
+    if(is_data_type_fixed_point(input->info()->data_type()))
+    {
+        options.emplace("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position()));
+
+        switch(input->info()->data_type())
+        {
+            case DataType::QS8:
+                promoted_type = DataType::QS16;
+                break;
+            case DataType::QS16:
+                promoted_type = DataType::QS32;
+                break;
+            default:
+                ARM_COMPUTE_ERROR("Datatype not supported");
+        }
+    }
+
+    options.emplace("-DDATA_TYPE_PROMOTED=" + get_cl_type_from_data_type(promoted_type));
+
     if(_biases != nullptr)
     {
         options.emplace("-DHAS_BIAS");
diff --git a/tests/validation_new/CL/DirectConvolutionLayer.cpp b/tests/validation_new/CL/DirectConvolutionLayer.cpp
index 1a7cd6b..d82f535 100644
--- a/tests/validation_new/CL/DirectConvolutionLayer.cpp
+++ b/tests/validation_new/CL/DirectConvolutionLayer.cpp
@@ -46,6 +46,9 @@
 constexpr AbsoluteTolerance<float> tolerance_fp16(0.1f);   /**< Tolerance for floating point tests */
 constexpr AbsoluteTolerance<float> tolerance_fp32(0.001f); /**< Tolerance for floating point tests */
 
+constexpr AbsoluteTolerance<int8_t>  tolerance_qs8(0);  /**< Tolerance for fixed point tests */
+constexpr AbsoluteTolerance<int16_t> tolerance_qs16(0); /**< Tolerance for fixed point tests */
+
 /** Direct convolution data set. */
 const auto data = combine(datasets::SmallDirectConvolutionShapes(),
                           combine(framework::dataset::make("StrideX", 1, 3),
@@ -85,6 +88,29 @@
 TEST_SUITE_END()
 TEST_SUITE_END()
 
+template <typename T>
+using CLDirectConvolutionLayerFixedPointFixture = DirectConvolutionValidationFixedPointFixture<CLTensor, CLAccessor, CLDirectConvolutionLayer, T>;
+
+TEST_SUITE(Quantized)
+TEST_SUITE(QS8)
+FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::QS8)),
+                                                                                                                    framework::dataset::make("FractionalBits", 2, 7)))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, tolerance_qs8);
+}
+TEST_SUITE_END()
+
+TEST_SUITE(QS16)
+FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::QS16)),
+                                                                                                                     framework::dataset::make("FractionalBits", 2, 15)))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, tolerance_qs16);
+}
+TEST_SUITE_END()
+TEST_SUITE_END()
+
 TEST_SUITE_END()
 TEST_SUITE_END()
 } // namespace validation