COMPMID-1533 Implementing CLDepthWiseConvolutionLayer with FP16 (NHWC)

Change-Id: I46965aeb1fffba8cbf083cab7284c549b0e94d00
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/145334
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
diff --git a/src/core/CL/cl_kernels/depthwise_convolution.cl b/src/core/CL/cl_kernels/depthwise_convolution.cl
index 77a76b6..23237da 100644
--- a/src/core/CL/cl_kernels/depthwise_convolution.cl
+++ b/src/core/CL/cl_kernels/depthwise_convolution.cl
@@ -996,13 +996,18 @@
 }
 #endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER)
 
-#if defined(VEC_SIZE) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT)
+#if defined(VEC_SIZE) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT) && defined(DATA_TYPE)
 
-#define VEC_FLOAT VEC_DATA_TYPE(float, VEC_SIZE)
+#if DATA_TYPE != float && DATA_TYPE != half // Only float and half are supported. NOTE(review): non-macro identifiers evaluate to 0 inside #if, so this guard cannot actually fire - consider a host-side check
+#error "Unsupported data type"
+#endif // DATA_TYPE != float && DATA_TYPE != half
+
+#define VEC_FLOAT VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) // Element type follows DATA_TYPE despite the "FLOAT" in the name; VEC_SIZE is the second argument (presumably the lane count - confirm against VEC_DATA_TYPE)
 
 #if defined(CONV_STRIDE_X) && defined(CONV_STRIDE_Y)
 /** This function computes the depthwise convolution for NHWC data layout when the stride along the width or height is not 1.
  *
+ * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float or -DDATA_TYPE=half)
  * @note The number of elements read per thread must be passed at compile time using -DVEC_SIZE (e.g. -DVEC_SIZE=2)
  * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM2 (e.g. -DSRC_DIM_2=112)
  * @note The convolution pad top must be passed at compile time using -DCONV_PAD_TOP (e.g. -DCONV_PAD_TOP=1)
@@ -1055,7 +1060,7 @@
 
     Vector weights = CONVERT_TO_VECTOR_STRUCT(weights);
 
-    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(float) * VEC_SIZE;
+    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) * VEC_SIZE;
 
     int  z_coord  = 0;
     int4 offset   = 0;
@@ -1065,15 +1070,15 @@
     VEC_FLOAT acc = 0;
 
     // Load weights
-    VEC_FLOAT w0 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 0 * weights_stride_y + 0 * weights_stride_z));
-    VEC_FLOAT w1 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 1 * weights_stride_y + 0 * weights_stride_z));
-    VEC_FLOAT w2 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 2 * weights_stride_y + 0 * weights_stride_z));
-    VEC_FLOAT w3 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 0 * weights_stride_y + 1 * weights_stride_z));
-    VEC_FLOAT w4 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 1 * weights_stride_y + 1 * weights_stride_z));
-    VEC_FLOAT w5 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 2 * weights_stride_y + 1 * weights_stride_z));
-    VEC_FLOAT w6 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 0 * weights_stride_y + 2 * weights_stride_z));
-    VEC_FLOAT w7 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 1 * weights_stride_y + 2 * weights_stride_z));
-    VEC_FLOAT w8 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 2 * weights_stride_y + 2 * weights_stride_z));
+    VEC_FLOAT w0 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 0 * weights_stride_y + 0 * weights_stride_z));
+    VEC_FLOAT w1 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 1 * weights_stride_y + 0 * weights_stride_z));
+    VEC_FLOAT w2 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 2 * weights_stride_y + 0 * weights_stride_z));
+    VEC_FLOAT w3 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 0 * weights_stride_y + 1 * weights_stride_z));
+    VEC_FLOAT w4 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 1 * weights_stride_y + 1 * weights_stride_z));
+    VEC_FLOAT w5 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 2 * weights_stride_y + 1 * weights_stride_z));
+    VEC_FLOAT w6 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 0 * weights_stride_y + 2 * weights_stride_z));
+    VEC_FLOAT w7 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 1 * weights_stride_y + 2 * weights_stride_z));
+    VEC_FLOAT w8 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 2 * weights_stride_y + 2 * weights_stride_z));
 
     // Load input values
     // z == 0
@@ -1085,27 +1090,27 @@
     offset  = y_offset + (int4)(z_coord * src_stride_z);
     offset  = min(offset, (int4)max_offset);
 
-    VEC_FLOAT values0 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s0));
-    VEC_FLOAT values1 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s1));
-    VEC_FLOAT values2 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s2));
+    VEC_FLOAT values0 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s0));
+    VEC_FLOAT values1 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s1));
+    VEC_FLOAT values2 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s2));
 
     // z == 1
     // z_coord can be only negative for z = 0 so we do not need to clamp it
     // Moreover z_coord cannot be out-of-bound for z = 1 so we do not need to clamp the offset
     z_coord           = z * CONV_STRIDE_Y - (int)CONV_PAD_TOP + 1;
     offset            = y_offset + (int4)(z_coord * src_stride_z);
-    VEC_FLOAT values3 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s0));
-    VEC_FLOAT values4 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s1));
-    VEC_FLOAT values5 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s2));
+    VEC_FLOAT values3 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s0));
+    VEC_FLOAT values4 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s1));
+    VEC_FLOAT values5 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s2));
 
     // z == 2
     // After z = 1 we can simply add src_stride_z to offset without updating z_coord
     // However offset can be out-of-bound so we need to check if it is greater than max_offset
     offset += (int4)src_stride_z;
     offset            = min(offset, (int4)max_offset);
-    VEC_FLOAT values6 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s0));
-    VEC_FLOAT values7 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s1));
-    VEC_FLOAT values8 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s2));
+    VEC_FLOAT values6 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s0));
+    VEC_FLOAT values7 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s1));
+    VEC_FLOAT values8 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s2));
 
     acc = fma(values0, w0, acc);
     acc = fma(values1, w1, acc);
@@ -1121,13 +1126,13 @@
 
 #if defined(HAS_BIAS)
     Vector    biases      = CONVERT_TO_VECTOR_STRUCT(biases);
-    VEC_FLOAT bias_values = VLOAD(VEC_SIZE)(0, (__global float *)biases.ptr);
+    VEC_FLOAT bias_values = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)biases.ptr);
     acc += bias_values;
 #endif // defined(HAS_BIAS)
 
     Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
     VSTORE(VEC_SIZE)
-    (acc, 0, (__global float *)(dst.ptr));
+    (acc, 0, (__global DATA_TYPE *)(dst.ptr));
 }
 #endif // defined(CONV_STRIDE_X) && defined(CONV_STRIDE_Y)
 
@@ -1186,7 +1191,7 @@
 
     Vector weights = CONVERT_TO_VECTOR_STRUCT(weights);
 
-    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(float) * VEC_SIZE;
+    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) * VEC_SIZE;
 
     int  z_coord  = 0;
     int4 offset   = 0;
@@ -1199,15 +1204,15 @@
     VEC_FLOAT acc3 = 0;
 
     // Load weights
-    VEC_FLOAT w0 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 0 * weights_stride_y + 0 * weights_stride_z));
-    VEC_FLOAT w1 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 1 * weights_stride_y + 0 * weights_stride_z));
-    VEC_FLOAT w2 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 2 * weights_stride_y + 0 * weights_stride_z));
-    VEC_FLOAT w3 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 0 * weights_stride_y + 1 * weights_stride_z));
-    VEC_FLOAT w4 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 1 * weights_stride_y + 1 * weights_stride_z));
-    VEC_FLOAT w5 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 2 * weights_stride_y + 1 * weights_stride_z));
-    VEC_FLOAT w6 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 0 * weights_stride_y + 2 * weights_stride_z));
-    VEC_FLOAT w7 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 1 * weights_stride_y + 2 * weights_stride_z));
-    VEC_FLOAT w8 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 2 * weights_stride_y + 2 * weights_stride_z));
+    VEC_FLOAT w0 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 0 * weights_stride_y + 0 * weights_stride_z));
+    VEC_FLOAT w1 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 1 * weights_stride_y + 0 * weights_stride_z));
+    VEC_FLOAT w2 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 2 * weights_stride_y + 0 * weights_stride_z));
+    VEC_FLOAT w3 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 0 * weights_stride_y + 1 * weights_stride_z));
+    VEC_FLOAT w4 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 1 * weights_stride_y + 1 * weights_stride_z));
+    VEC_FLOAT w5 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 2 * weights_stride_y + 1 * weights_stride_z));
+    VEC_FLOAT w6 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 0 * weights_stride_y + 2 * weights_stride_z));
+    VEC_FLOAT w7 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 1 * weights_stride_y + 2 * weights_stride_z));
+    VEC_FLOAT w8 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(weights.ptr + 2 * weights_stride_y + 2 * weights_stride_z));
 
     // Load input values
     // z == 0
@@ -1219,40 +1224,40 @@
     offset  = y_offset + (int4)(z_coord * src_stride_z);
     offset  = min(offset, (int4)max_offset);
 
-    VEC_FLOAT values0 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s0));
-    VEC_FLOAT values1 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s1));
-    VEC_FLOAT values2 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s2));
-    VEC_FLOAT values3 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s3));
+    VEC_FLOAT values0 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s0));
+    VEC_FLOAT values1 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s1));
+    VEC_FLOAT values2 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s2));
+    VEC_FLOAT values3 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s3));
 
     // z == 1
     // z_coord can be only negative for z = 0 so we do not need to clamp it
     // Moreover z_coord cannot be out-of-bound for z = 1 so we do not need to clamp the offset
     z_coord           = z * (int)NUM_PLANES_PROCESSED - (int)CONV_PAD_TOP + 1;
     offset            = y_offset + (int4)(z_coord * src_stride_z);
-    VEC_FLOAT values4 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s0));
-    VEC_FLOAT values5 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s1));
-    VEC_FLOAT values6 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s2));
-    VEC_FLOAT values7 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s3));
+    VEC_FLOAT values4 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s0));
+    VEC_FLOAT values5 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s1));
+    VEC_FLOAT values6 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s2));
+    VEC_FLOAT values7 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s3));
 
     // z == 2
     // After z = 1 we can simply add src_stride_z to offset without updating z_coord
     // However offset can be out-of-bound so we need to check if it is greater than max_offset
     offset += (int4)src_stride_z;
     offset             = min(offset, (int4)max_offset);
-    VEC_FLOAT values8  = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s0));
-    VEC_FLOAT values9  = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s1));
-    VEC_FLOAT values10 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s2));
-    VEC_FLOAT values11 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s3));
+    VEC_FLOAT values8  = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s0));
+    VEC_FLOAT values9  = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s1));
+    VEC_FLOAT values10 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s2));
+    VEC_FLOAT values11 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s3));
 
     // z == 3
     // After z = 1 we can simply add src_stride_z to offset without updating z_coord
     // However offset can be out-of-bound so we need to check if it is greater than max_offset
     offset += (int4)src_stride_z;
     offset             = min(offset, (int4)max_offset);
-    VEC_FLOAT values12 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s0));
-    VEC_FLOAT values13 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s1));
-    VEC_FLOAT values14 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s2));
-    VEC_FLOAT values15 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s3));
+    VEC_FLOAT values12 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s0));
+    VEC_FLOAT values13 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s1));
+    VEC_FLOAT values14 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s2));
+    VEC_FLOAT values15 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_addr + offset.s3));
 
     acc0 = fma(values0, w0, acc0);
     acc0 = fma(values1, w1, acc0);
@@ -1299,7 +1304,7 @@
 #if defined(HAS_BIAS)
     Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);
 
-    VEC_FLOAT bias_values = VLOAD(VEC_SIZE)(0, (__global float *)biases.ptr);
+    VEC_FLOAT bias_values = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)biases.ptr);
 
     acc0 += bias_values;
     acc1 += bias_values;
@@ -1310,20 +1315,20 @@
     __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * dst_step_x + y * dst_step_y + (z * NUM_PLANES_PROCESSED) * dst_step_z;
 
     VSTORE(VEC_SIZE)
-    (acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+    (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
     VSTORE(VEC_SIZE)
-    (acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+    (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
 
 #if((DST_DIM_2 % NUM_PLANES_PROCESSED) != 0)
     if((z * NUM_PLANES_PROCESSED + 1) < DST_DIM_2)
 #endif // ((DST_DIM_2 % NUM_PLANES_PROCESSED) != 0)
     {
         VSTORE(VEC_SIZE)
-        (acc2, 0, (__global float *)(dst_addr + 0 * dst_stride_y + 1 * dst_stride_z));
+        (acc2, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + 1 * dst_stride_z));
         VSTORE(VEC_SIZE)
-        (acc3, 0, (__global float *)(dst_addr + 1 * dst_stride_y + 1 * dst_stride_z));
+        (acc3, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + 1 * dst_stride_z));
     }
 }
 
 #endif // defined(NUM_ROWS_PROCESSED) && defined(NUM_PLANES_PROCESSED)
-#endif // defined(VEC_SIZE) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT)
\ No newline at end of file
+#endif // defined(VEC_SIZE) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT) && defined(DATA_TYPE)
\ No newline at end of file