COMPMID-3535: 9x9 Direct convolution support for CL and NHWC

* Supported strides 1 and 2

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I4b9f087c0c328234159b2d1eacc2e465b3bb3c54
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3603
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/direct_convolution_quantized.cl b/src/core/CL/cl_kernels/direct_convolution_quantized.cl
index ed1b7cf..8237fe1 100644
--- a/src/core/CL/cl_kernels/direct_convolution_quantized.cl
+++ b/src/core/CL/cl_kernels/direct_convolution_quantized.cl
@@ -33,7 +33,113 @@
 
 #if defined(DATA_LAYOUT_NHWC)
 
-#if KERNEL_SIZE == 5
+#if KERNEL_SIZE == 9
+
+#if STRIDE_X == 1 /* Bind CONVOLUTION1x9 to the row macro that matches the compile-time STRIDE_X */
+#define CONVOLUTION1x9(acc, src_ptr, weights_ptr) CONVOLUTION1x9_STRIDE1(acc, src_ptr, weights_ptr)
+#elif STRIDE_X == 2
+#define CONVOLUTION1x9(acc, src_ptr, weights_ptr) CONVOLUTION1x9_STRIDE2(acc, src_ptr, weights_ptr)
+#else /* STRIDE_X is neither 1 nor 2 (the error below also fires for values smaller than 1) */
+#error "STRIDE_X larger than 2 is not supported"
+#endif /* STRIDE_X */
+
+#define CONVOLUTION1x9_STRIDE1(acc, src_ptr, weights_ptr)                          \
+    ({ /* One 1x9 row at stride 1: acc.sJ += sum over k = 0..8 of src[J + k] * w[k], J = 0..7; acc is used as an 8-lane int vector */ \
+        int8 weights_values0 = 0; /* taps w[0..7] */                               \
+        int  weights_value1  = 0; /* tap w[8] */                                   \
+        weights_values0.s0   = convert_int(*(weights_ptr + 0 * weights_stride_y)); \
+        weights_values0.s1   = convert_int(*(weights_ptr + 1 * weights_stride_y)); \
+        weights_values0.s2   = convert_int(*(weights_ptr + 2 * weights_stride_y)); \
+        weights_values0.s3   = convert_int(*(weights_ptr + 3 * weights_stride_y)); \
+        weights_values0.s4   = convert_int(*(weights_ptr + 4 * weights_stride_y)); \
+        weights_values0.s5   = convert_int(*(weights_ptr + 5 * weights_stride_y)); \
+        weights_values0.s6   = convert_int(*(weights_ptr + 6 * weights_stride_y)); \
+        weights_values0.s7   = convert_int(*(weights_ptr + 7 * weights_stride_y)); \
+        weights_value1       = convert_int(*(weights_ptr + 8 * weights_stride_y)); \
+        /* Gather src[0..15]. NOTE(review): the step is weights_stride_y, not a source stride - confirm the two always match */ \
+        int8 src0 = 0;                                                             \
+        int8 src1 = 0;                                                             \
+        src0.s0   = convert_int(*(src_ptr + 0 * weights_stride_y));                \
+        src0.s1   = convert_int(*(src_ptr + 1 * weights_stride_y));                \
+        src0.s2   = convert_int(*(src_ptr + 2 * weights_stride_y));                \
+        src0.s3   = convert_int(*(src_ptr + 3 * weights_stride_y));                \
+        src0.s4   = convert_int(*(src_ptr + 4 * weights_stride_y));                \
+        src0.s5   = convert_int(*(src_ptr + 5 * weights_stride_y));                \
+        src0.s6   = convert_int(*(src_ptr + 6 * weights_stride_y));                \
+        src0.s7   = convert_int(*(src_ptr + 7 * weights_stride_y));                \
+        src1.s0   = convert_int(*(src_ptr + 8 * weights_stride_y));                \
+        src1.s1   = convert_int(*(src_ptr + 9 * weights_stride_y));                \
+        src1.s2   = convert_int(*(src_ptr + 10 * weights_stride_y));               \
+        src1.s3   = convert_int(*(src_ptr + 11 * weights_stride_y));               \
+        src1.s4   = convert_int(*(src_ptr + 12 * weights_stride_y));               \
+        src1.s5   = convert_int(*(src_ptr + 13 * weights_stride_y));               \
+        src1.s6   = convert_int(*(src_ptr + 14 * weights_stride_y));               \
+        src1.s7   = convert_int(*(src_ptr + 15 * weights_stride_y));               \
+        /* Tap k scales the 8 consecutive source values src[k..k+7]; each swizzle below assembles that window */ \
+        acc += src0 * (int8)weights_values0.s0;                                    \
+        acc += (int8)(src0.s1234, src0.s567, src1.s0) * (int8)weights_values0.s1;  \
+        acc += (int8)(src0.s234, src0.s567, src1.s01) * (int8)weights_values0.s2;  \
+        acc += (int8)(src0.s345, src0.s67, src1.s012) * (int8)weights_values0.s3;  \
+        acc += (int8)(src0.s4567, src1.s0123) * (int8)weights_values0.s4;          \
+        acc += (int8)(src0.s567, src1.s0123, src1.s4) * (int8)weights_values0.s5;  \
+        acc += (int8)(src0.s67, src1.s012, src1.s345) * (int8)weights_values0.s6;  \
+        acc += (int8)(src0.s7, src1.s0123, src1.s456) * (int8)weights_values0.s7;  \
+        acc += src1 * (int8)weights_value1;                                        \
+    })
+
+#define CONVOLUTION1x9_STRIDE2(acc, src_ptr, weights_ptr)                          \
+    ({ /* One 1x9 row at stride 2: acc.sJ += sum over k = 0..8 of src[2 * J + k] * w[k], J = 0..7; acc is used as an 8-lane int vector */ \
+        int8 weights_values0 = 0; /* taps w[0..7] */                               \
+        int  weights_value1  = 0; /* tap w[8] */                                   \
+        weights_values0.s0   = convert_int(*(weights_ptr + 0 * weights_stride_y)); \
+        weights_values0.s1   = convert_int(*(weights_ptr + 1 * weights_stride_y)); \
+        weights_values0.s2   = convert_int(*(weights_ptr + 2 * weights_stride_y)); \
+        weights_values0.s3   = convert_int(*(weights_ptr + 3 * weights_stride_y)); \
+        weights_values0.s4   = convert_int(*(weights_ptr + 4 * weights_stride_y)); \
+        weights_values0.s5   = convert_int(*(weights_ptr + 5 * weights_stride_y)); \
+        weights_values0.s6   = convert_int(*(weights_ptr + 6 * weights_stride_y)); \
+        weights_values0.s7   = convert_int(*(weights_ptr + 7 * weights_stride_y)); \
+        weights_value1       = convert_int(*(weights_ptr + 8 * weights_stride_y)); \
+        /* Gather src[0..23] (src0 holds 0..15, src1 holds 16..23). NOTE(review): the step is weights_stride_y, not a source stride - confirm the two always match */ \
+        int16 src0 = 0;                                                            \
+        int8  src1 = 0;                                                            \
+        src0.s0    = convert_int(*(src_ptr + 0 * weights_stride_y));               \
+        src0.s1    = convert_int(*(src_ptr + 1 * weights_stride_y));               \
+        src0.s2    = convert_int(*(src_ptr + 2 * weights_stride_y));               \
+        src0.s3    = convert_int(*(src_ptr + 3 * weights_stride_y));               \
+        src0.s4    = convert_int(*(src_ptr + 4 * weights_stride_y));               \
+        src0.s5    = convert_int(*(src_ptr + 5 * weights_stride_y));               \
+        src0.s6    = convert_int(*(src_ptr + 6 * weights_stride_y));               \
+        src0.s7    = convert_int(*(src_ptr + 7 * weights_stride_y));               \
+        src0.s8    = convert_int(*(src_ptr + 8 * weights_stride_y));               \
+        src0.s9    = convert_int(*(src_ptr + 9 * weights_stride_y));               \
+        src0.sA    = convert_int(*(src_ptr + 10 * weights_stride_y));              \
+        src0.sB    = convert_int(*(src_ptr + 11 * weights_stride_y));              \
+        src0.sC    = convert_int(*(src_ptr + 12 * weights_stride_y));              \
+        src0.sD    = convert_int(*(src_ptr + 13 * weights_stride_y));              \
+        src0.sE    = convert_int(*(src_ptr + 14 * weights_stride_y));              \
+        src0.sF    = convert_int(*(src_ptr + 15 * weights_stride_y));              \
+        src1.s0    = convert_int(*(src_ptr + 16 * weights_stride_y));              \
+        src1.s1    = convert_int(*(src_ptr + 17 * weights_stride_y));              \
+        src1.s2    = convert_int(*(src_ptr + 18 * weights_stride_y));              \
+        src1.s3    = convert_int(*(src_ptr + 19 * weights_stride_y));              \
+        src1.s4    = convert_int(*(src_ptr + 20 * weights_stride_y));              \
+        src1.s5    = convert_int(*(src_ptr + 21 * weights_stride_y));              \
+        src1.s6    = convert_int(*(src_ptr + 22 * weights_stride_y));              \
+        src1.s7    = convert_int(*(src_ptr + 23 * weights_stride_y));              \
+        /* Tap k scales every second source value starting at src[k]: src[k], src[k + 2], ..., src[k + 14] */ \
+        acc += src0.s02468ACE * (int8)weights_values0.s0;                          \
+        acc += (int8)(src0.s1357, src0.s9BDF) * (int8)weights_values0.s1;          \
+        acc += (int8)(src0.s2468, src0.sACE, src1.s0) * (int8)weights_values0.s2;  \
+        acc += (int8)(src0.s3579, src0.sBDF, src1.s1) * (int8)weights_values0.s3;  \
+        acc += (int8)(src0.s468A, src0.sCE, src1.s02) * (int8)weights_values0.s4;  \
+        acc += (int8)(src0.s579, src0.sBDF, src1.s13) * (int8)weights_values0.s5;  \
+        acc += (int8)(src0.s68A, src0.sCE, src1.s024) * (int8)weights_values0.s6;  \
+        acc += (int8)(src0.s79B, src0.sDF, src1.s135) * (int8)weights_values0.s7;  \
+        acc += (int8)(src0.s8AC, src0.sE, src1.s0246) * (int8)weights_value1;      \
+    })
+
+#elif KERNEL_SIZE == 5
 
 #if STRIDE_X == 1
 #define CONVOLUTION1x5(acc, src_ptr, weights_ptr) CONVOLUTION1x5_STRIDE1(acc, src_ptr, weights_ptr)
@@ -331,7 +437,37 @@
 
     for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d)
     {
-#if KERNEL_SIZE == 5
+#if KERNEL_SIZE == 9
+        if(y_coord < 0)
+        {
+            const int start_z = -y_coord;
+            for(int i = start_z; i < 9; ++i)
+            {
+                CONVOLUTION1x9(values0, (src_addr + i * (int)src_stride_z), (weights_addr + i * (int)weights_stride_z));
+            }
+        }
+        else if(y_coord > (SRC_HEIGHT - 9))
+        {
+            // Avoid loading rows beyond the input height
+            const int end_z = SRC_HEIGHT - y_coord;
+            for(int i = 0; i < end_z; ++i)
+            {
+                CONVOLUTION1x9(values0, (src_addr + i * (int)src_stride_z), (weights_addr + i * (int)weights_stride_z));
+            }
+        }
+        else
+        {
+            CONVOLUTION1x9(values0, src_addr, weights_addr);
+            CONVOLUTION1x9(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z));
+            CONVOLUTION1x9(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z));
+            CONVOLUTION1x9(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z));
+            CONVOLUTION1x9(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z));
+            CONVOLUTION1x9(values0, (src_addr + 5 * (int)src_stride_z), (weights_addr + 5 * (int)weights_stride_z));
+            CONVOLUTION1x9(values0, (src_addr + 6 * (int)src_stride_z), (weights_addr + 6 * (int)weights_stride_z));
+            CONVOLUTION1x9(values0, (src_addr + 7 * (int)src_stride_z), (weights_addr + 7 * (int)weights_stride_z));
+            CONVOLUTION1x9(values0, (src_addr + 8 * (int)src_stride_z), (weights_addr + 8 * (int)weights_stride_z));
+        }
+#elif KERNEL_SIZE == 5
 #if(PAD_TOP == 1) || (PAD_BOTTM == 1)
         if(y_coord < 0) // special case Z = -1 doesn't exists
         {