COMPMID-799 - Use new OpenCL 8-bit dot product instruction

Change-Id: I03d6c6db13bcb565f117725bdab2b68c89a49e21
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122185
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index 5e144d7..da91577 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -190,6 +190,63 @@
 #if MULT_INTERLEAVE4X4_HEIGHT == 1
     for(; src_addr_b <= (src_end_addr_b - (int)(32 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += (32 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (32 * TRANSPOSE1XW_WIDTH_STEP))
     {
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+        // Load 16 values from matrix A (interleaved) and four uchar4 vectors from matrix B (transposed)
+        uchar16 a0 = vload16(0, src_addr_a);
+        uchar4  b0 = vload4(0, src_addr_b);
+        uchar4  b1 = vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP);
+        uchar4  b2 = vload4(0, src_addr_b + 8 * TRANSPOSE1XW_WIDTH_STEP);
+        uchar4  b3 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP);
+
+        // Accumulate: each arm_dot takes four a0 lanes spaced 4 apart and the matching lane gathered from b0..b3
+        c00 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c01 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c02 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c03 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+        c10 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c11 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c12 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c13 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+        c20 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c21 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c22 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c23 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+        c30 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c31 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c32 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c33 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+        // Load the next 16 values from matrix A (interleaved) and the next four uchar4 vectors from matrix B (transposed)
+        a0 = vload16(0, src_addr_a + 16);
+        b0 = vload4(0, src_addr_b + 16 * TRANSPOSE1XW_WIDTH_STEP);
+        b1 = vload4(0, src_addr_b + 20 * TRANSPOSE1XW_WIDTH_STEP);
+        b2 = vload4(0, src_addr_b + 24 * TRANSPOSE1XW_WIDTH_STEP);
+        b3 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP);
+
+        // Accumulate
+        c00 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c01 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c02 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c03 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+        c10 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c11 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c12 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c13 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+        c20 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c21 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c22 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c23 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+        c30 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c31 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c32 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c33 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+#else  // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         // Load values from matrix A (interleaved) and matrix B (transposed)
         uchar16 a0 = vload16(0, src_addr_a);
         uchar4  b0 = vload4(0, src_addr_b);
@@ -375,6 +432,7 @@
         c31 += (ushort)a0.sF * b0.s1;
         c32 += (ushort)a0.sF * b0.s2;
         c33 += (ushort)a0.sF * b0.s3;
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
     }
 #endif // MULT_INTERLEAVE4X4_HEIGHT == 1
 
@@ -666,6 +724,13 @@
         uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
 
         {
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+            // Accumulate: each arm_dot takes one lane gathered from b0..b3 and the four values in a0
+            acc00 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a0);
+            acc01 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a0);
+            acc02 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a0);
+            acc03 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a0);
+#else  // ARM_COMPUTE_OPENCL_DOT8_ENABLED
             // Accumulate
             ushort tmp0 = (ushort)b0.s0 * (ushort)a0.s0;
             ushort tmp1 = (ushort)b0.s1 * (ushort)a0.s0;
@@ -691,9 +756,17 @@
             acc01 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
             acc02 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
             acc03 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         }
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
         {
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+            // Accumulate
+            acc10 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a1);
+            acc11 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a1);
+            acc12 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a1);
+            acc13 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a1);
+#else  // ARM_COMPUTE_OPENCL_DOT8_ENABLED
             // Accumulate
             ushort tmp0 = (ushort)b0.s0 * (ushort)a1.s0;
             ushort tmp1 = (ushort)b0.s1 * (ushort)a1.s0;
@@ -719,10 +792,18 @@
             acc11 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
             acc12 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
             acc13 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
         {
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+            // Accumulate
+            acc20 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a2);
+            acc21 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a2);
+            acc22 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a2);
+            acc23 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a2);
+#else  // ARM_COMPUTE_OPENCL_DOT8_ENABLED
             // Accumulate
             ushort tmp0 = (ushort)b0.s0 * (ushort)a2.s0;
             ushort tmp1 = (ushort)b0.s1 * (ushort)a2.s0;
@@ -748,10 +829,18 @@
             acc21 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
             acc22 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
             acc23 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
         {
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+            // Accumulate
+            acc30 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a3);
+            acc31 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a3);
+            acc32 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a3);
+            acc33 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a3);
+#else  // ARM_COMPUTE_OPENCL_DOT8_ENABLED
             // Accumulate
             ushort tmp0 = (ushort)b0.s0 * (ushort)a3.s0;
             ushort tmp1 = (ushort)b0.s1 * (ushort)a3.s0;
@@ -777,10 +866,18 @@
             acc31 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
             acc32 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
             acc33 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
         {
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+            // Accumulate
+            acc40 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a4);
+            acc41 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a4);
+            acc42 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a4);
+            acc43 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a4);
+#else  // ARM_COMPUTE_OPENCL_DOT8_ENABLED
             // Accumulate
             ushort tmp0 = (ushort)b0.s0 * (ushort)a4.s0;
             ushort tmp1 = (ushort)b0.s1 * (ushort)a4.s0;
@@ -806,6 +903,7 @@
             acc41 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
             acc42 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
             acc43 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
     }