APPBROWSER-332 Code refactoring for batchnormalization_layer.cs

Change-Id: Ib695e7551994a10355c823840d3fb6237aef0a65
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/112054
Reviewed-by: Joel Liang <joel.liang@arm.com>
Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/GLES_COMPUTE/cs_shaders/batchnormalization_layer.cs b/src/core/GLES_COMPUTE/cs_shaders/batchnormalization_layer.cs
index be1d01f..53fb515 100644
--- a/src/core/GLES_COMPUTE/cs_shaders/batchnormalization_layer.cs
+++ b/src/core/GLES_COMPUTE/cs_shaders/batchnormalization_layer.cs
@@ -24,11 +24,9 @@
 
 layout(local_size_x = LOCAL_SIZE_X, local_size_y = LOCAL_SIZE_Y, local_size_z = LOCAL_SIZE_Z) in;
 
-#include "helpers.h"
+#include "helpers_cs.h"
 
-#ifdef DATA_TYPE_FP32
-precision highp float;
-#elif defined(DATA_TYPE_FP16)
+#if defined(DATA_TYPE_FP16)
 precision mediump float;
-#endif /*DATA_TYPE_FP32*/
+#endif /*DATA_TYPE_FP16*/
 
@@ -38,69 +36,50 @@
 #define INVSQRT_OP(a) inversesqrt((a))
 #define SQCVT_SAT(a) (a)
 
-layout(std140) uniform shader_params
+/** Apply batch normalization.
+ *
+ * @note The data type must be passed at compile time using "#define DATA_TYPE_NAME". e.g. "#define DATA_TYPE_FP32" or "#define DATA_TYPE_FP16"
+ * @note Epsilon parameter in the batch normalization equation must be given as a preprocessor argument using "#define ESPILON" (note the spelling: the shader body reads ESPILON). e.g. "#define ESPILON 0.1"
+ *
+ * @param[in]  src_ptr     Pointer to the source tensor. Supported data types: F16/F32
+ * @param[in]  src_attrs   The attributes of the source tensor
+ * @param[out] dst_ptr     Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in]  dst_attrs   The attributes of the destination tensor
+ * @param[in]  mean_ptr    Pointer to the mean tensor (one element per Z slice of the source). Supported data types: same as @p src_ptr
+ * @param[in]  mean_attrs  The attributes of the mean tensor
+ * @param[in]  var_ptr     Pointer to the variance tensor (one element per Z slice of the source). Supported data types: same as @p src_ptr
+ * @param[in]  var_attrs   The attributes of the variance tensor
+ * @param[in]  beta_ptr    Pointer to the beta tensor (one element per Z slice of the source). Supported data types: same as @p src_ptr
+ * @param[in]  beta_attrs  The attributes of the beta tensor
+ * @param[in]  gamma_ptr   Pointer to the gamma tensor (one element per Z slice of the source). Supported data types: same as @p src_ptr
+ * @param[in]  gamma_attrs The attributes of the gamma tensor
+ */
+SHADER_PARAMS_DECLARATION
 {
-    TENSOR3D_PARAM_DECLARATION(src);
-    TENSOR3D_PARAM_DECLARATION(dst);
-    VECTOR_PARAM_DECLARATION(mean);
-    VECTOR_PARAM_DECLARATION(var);
-    VECTOR_PARAM_DECLARATION(beta);
-    VECTOR_PARAM_DECLARATION(gamma);
+    Tensor3DAttributes src_attrs;
+    Tensor3DAttributes dst_attrs;
+    VectorAttributes   mean_attrs;
+    VectorAttributes   var_attrs;
+    VectorAttributes   beta_attrs;
+    VectorAttributes   gamma_attrs;
 };
 
 #ifdef DATA_TYPE_FP32
-BUFFER_DECLARATION(src, 1, float, readonly);
-BUFFER_DECLARATION(dst, 2, float, writeonly);
-BUFFER_DECLARATION(mean, 3, float, readonly);
-BUFFER_DECLARATION(var, 4, float, readonly);
-BUFFER_DECLARATION(beta, 5, float, readonly);
-BUFFER_DECLARATION(gamma, 6, float, readonly);
+TENSOR_DECLARATION(1, srcBuffer, float, src_ptr, src_shift, 2, readonly);
+TENSOR_DECLARATION(2, dstBuffer, float, dst_ptr, dst_shift, 2, writeonly);
+TENSOR_DECLARATION(3, meanBuffer, float, mean_ptr, mean_shift, 2, readonly);
+TENSOR_DECLARATION(4, varBuffer, float, var_ptr, var_shift, 2, readonly);
+TENSOR_DECLARATION(5, betaBuffer, float, beta_ptr, beta_shift, 2, readonly);
+TENSOR_DECLARATION(6, gammaBuffer, float, gamma_ptr, gamma_shift, 2, readonly);
 
-/** Apply batch normalization.
- *
- * @note Epsilon parameter in the batch normalization equation should be given as a preprocessor argument using "#define EPSILON". e.g. "#define EPSILON 0.1"
- *
- * @param[in]  src_ptr                             Pointer to the first source tensor. Supported data types: F32
- * @param[in]  src_stride_x                        Stride of the first source tensor in X dimension (in bytes)
- * @param[in]  src_step_x                          src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src_stride_y                        Stride of the first source tensor in Y dimension (in bytes)
- * @param[in]  src_step_y                          src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  src_stride_z                        Stride of the first source tensor in Z dimension (in bytes)
- * @param[in]  src_step_z                          src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  src_offset_first_element_in_bytes   The offset of the first element in the first source tensor
- * @param[out] dst_ptr                             Pointer to the destination tensor. Supported data types: same as @p src_ptr
- * @param[in]  dst_stride_x                        Stride of the destination tensor in X dimension (in bytes)
- * @param[in]  dst_step_x                          dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  dst_stride_y                        Stride of the destination tensor in Y dimension (in bytes)
- * @param[in]  dst_step_y                          dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  dst_stride_z                        Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  dst_step_z                          dst_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  dst_offset_first_element_in_bytes   The offset of the first element in the destination tensor
- * @param[in]  mean_ptr                            Pointer to the mean source tensor. Supported data types: same as @p src_ptr
- * @param[in]  mean_stride_x                       Stride of the mean source tensor in X dimension (in bytes)
- * @param[in]  mean_step_x                         mean_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  mean_offset_first_element_in_bytes  The offset of the first element in the mean source tensor
- * @param[in]  var_ptr                             Pointer to the var tensor. Supported data types: same as @p src_ptr
- * @param[in]  var_stride_x                        Stride of the var tensor in X dimension (in bytes)
- * @param[in]  var_step_x                          var_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  var_offset_first_element_in_bytes   The offset of the first element in the var source tensor
- * @param[in]  beta_ptr                            Pointer to the beta source tensor. Supported data types: same as @p src_ptr
- * @param[in]  beta_stride_x                       Stride of the beta source tensor in X dimension (in bytes)
- * @param[in]  beta_step_x                         beta_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  beta_offset_first_element_in_bytes  The offset of the first element in the beta source tensor
- * @param[in]  gamma_ptr                           Pointer to the gamma source tensor. Supported data types: same as @p src_ptr
- * @param[in]  gamma_stride_x                      Stride of the gamma source tensor in X dimension (in bytes)
- * @param[in]  gamma_step_x                        gamma_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  gamma_offset_first_element_in_bytes The offset of the first element in the gamma source tensor
- */
 void main(void)
 {
-    Tensor3D src   = CONVERT_TO_TENSOR3D_STRUCT(src);
-    Tensor3D dst   = CONVERT_TO_TENSOR3D_STRUCT(dst);
-    Vector   mean  = CONVERT_TO_VECTOR_STRUCT(mean);
-    Vector   var   = CONVERT_TO_VECTOR_STRUCT(var);
-    Vector   beta  = CONVERT_TO_VECTOR_STRUCT(beta);
-    Vector   gamma = CONVERT_TO_VECTOR_STRUCT(gamma);
+    Tensor3DIterator src_iter   = CONVERT_TO_TENSOR3D_ITERATOR(src_attrs, src_shift);
+    Tensor3DIterator dst_iter   = CONVERT_TO_TENSOR3D_ITERATOR(dst_attrs, dst_shift);
+    VectorIterator   mean_iter  = CONVERT_TO_VECTOR_ITERATOR(mean_attrs, mean_shift);
+    VectorIterator   var_iter   = CONVERT_TO_VECTOR_ITERATOR(var_attrs, var_shift);
+    VectorIterator   beta_iter  = CONVERT_TO_VECTOR_ITERATOR(beta_attrs, beta_shift);
+    VectorIterator   gamma_iter = CONVERT_TO_VECTOR_ITERATOR(gamma_attrs, gamma_shift);
 
     float input_value = 0.f;
     float denominator = 0.f;
@@ -111,76 +90,38 @@
 
     uint current_slice = gl_GlobalInvocationID.z;
 
-    input_value = src_ptr[src.current_offset];
-    denominator = var_ptr[var.current_offset + (current_slice * var.stride_x) >> 2];
+    input_value = LOAD_CURRENT_ITEM(src_ptr, src_iter);
+    denominator = LOAD(var_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(var_iter, current_slice * var_attrs.stride_x));
     denominator = INVSQRT_OP(ADD_OP(denominator, SQCVT_SAT(float(ESPILON))));
 
     // Calculate x bar and store results
-    numerator = mean_ptr[mean.current_offset + (current_slice * mean.stride_x) >> 2];
+    numerator = LOAD(mean_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(mean_iter, current_slice * mean_attrs.stride_x));
     numerator = SUB_OP(input_value, numerator);
     x_bar     = MUL_OP(numerator, denominator);
 
-    gamma_param = gamma_ptr[gamma.current_offset + (current_slice * beta.stride_x) >> 2];
-    beta_param  = beta_ptr[beta.current_offset + (current_slice * beta.stride_x) >> 2];
+    gamma_param = LOAD(gamma_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(gamma_iter, current_slice * gamma_attrs.stride_x));
+    beta_param  = LOAD(beta_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(beta_iter, current_slice * beta_attrs.stride_x));
 
-    dst_ptr[dst.current_offset] = ADD_OP(MUL_OP(gamma_param, x_bar), beta_param);
+    STORE_CURRENT_ITEM(dst_ptr, dst_iter, ADD_OP(MUL_OP(gamma_param, x_bar), beta_param));
 }
 
 #elif defined(DATA_TYPE_FP16)
-BUFFER_DECLARATION(src, 1, uvec2, readonly);
-BUFFER_DECLARATION(dst, 2, uvec2, writeonly);
-BUFFER_DECLARATION(mean, 3, uvec2, readonly);
-BUFFER_DECLARATION(var, 4, uvec2, readonly);
-BUFFER_DECLARATION(beta, 5, uvec2, readonly);
-BUFFER_DECLARATION(gamma, 6, uvec2, readonly);
+TENSOR_DECLARATION(1, srcBuffer, uvec2, src_ptr, src_shift, 3, readonly);
+TENSOR_DECLARATION(2, dstBuffer, uvec2, dst_ptr, dst_shift, 3, writeonly);
+TENSOR_DECLARATION(3, meanBuffer, uvec2, mean_ptr, mean_shift, 3, readonly);
+TENSOR_DECLARATION(4, varBuffer, uvec2, var_ptr, var_shift, 3, readonly);
+TENSOR_DECLARATION(5, betaBuffer, uvec2, beta_ptr, beta_shift, 3, readonly);
+TENSOR_DECLARATION(6, gammaBuffer, uvec2, gamma_ptr, gamma_shift, 3, readonly);
 
-/** Apply batch normalization.
- *
- * @note Epsilon parameter in the batch normalization equation should be given as a preprocessor argument using "#define EPSILON". e.g. "#define EPSILON 0.1"
- *
- * @param[in]  src_ptr                             Pointer to the first source tensor. Supported data types: F16
- * @param[in]  src_stride_x                        Stride of the first source tensor in X dimension (in bytes)
- * @param[in]  src_step_x                          src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src_stride_y                        Stride of the first source tensor in Y dimension (in bytes)
- * @param[in]  src_step_y                          src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  src_stride_z                        Stride of the first source tensor in Z dimension (in bytes)
- * @param[in]  src_step_z                          src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  src_offset_first_element_in_bytes   The offset of the first element in the first source tensor
- * @param[out] dst_ptr                             Pointer to the destination tensor. Supported data types: same as @p src_ptr
- * @param[in]  dst_stride_x                        Stride of the destination tensor in X dimension (in bytes)
- * @param[in]  dst_step_x                          dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  dst_stride_y                        Stride of the destination tensor in Y dimension (in bytes)
- * @param[in]  dst_step_y                          dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  dst_stride_z                        Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  dst_step_z                          dst_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  dst_offset_first_element_in_bytes   The offset of the first element in the destination tensor
- * @param[in]  mean_ptr                            Pointer to the mean source tensor. Supported data types: same as @p src_ptr
- * @param[in]  mean_stride_x                       Stride of the mean source tensor in X dimension (in bytes)
- * @param[in]  mean_step_x                         mean_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  mean_offset_first_element_in_bytes  The offset of the first element in the mean source tensor
- * @param[in]  var_ptr                             Pointer to the var tensor. Supported data types: same as @p src_ptr
- * @param[in]  var_stride_x                        Stride of the var tensor in X dimension (in bytes)
- * @param[in]  var_step_x                          var_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  var_offset_first_element_in_bytes   The offset of the first element in the var source tensor
- * @param[in]  beta_ptr                            Pointer to the beta source tensor. Supported data types: same as @p src_ptr
- * @param[in]  beta_stride_x                       Stride of the beta source tensor in X dimension (in bytes)
- * @param[in]  beta_step_x                         beta_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  beta_offset_first_element_in_bytes  The offset of the first element in the beta source tensor
- * @param[in]  gamma_ptr                           Pointer to the gamma source tensor. Supported data types: same as @p src_ptr
- * @param[in]  gamma_stride_x                      Stride of the gamma source tensor in X dimension (in bytes)
- * @param[in]  gamma_step_x                        gamma_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  gamma_offset_first_element_in_bytes The offset of the first element in the gamma source tensor
- */
 void main(void)
 {
-    Tensor3D src   = CONVERT_TO_TENSOR3D_STRUCT_FP16(src);
-    Tensor3D dst   = CONVERT_TO_TENSOR3D_STRUCT_FP16(dst);
-    Vector   mean  = CONVERT_TO_VECTOR_STRUCT_FP16(mean);
-    Vector   var   = CONVERT_TO_VECTOR_STRUCT_FP16(var);
-    Vector   beta  = CONVERT_TO_VECTOR_STRUCT_FP16(beta);
-    Vector   gamma = CONVERT_TO_VECTOR_STRUCT_FP16(gamma);
+    Tensor3DIterator src_iter   = CONVERT_TO_TENSOR3D_ITERATOR(src_attrs, src_shift);
+    Tensor3DIterator dst_iter   = CONVERT_TO_TENSOR3D_ITERATOR(dst_attrs, dst_shift);
+    VectorIterator   mean_iter  = CONVERT_TO_VECTOR_ITERATOR(mean_attrs, mean_shift);
+    VectorIterator   var_iter   = CONVERT_TO_VECTOR_ITERATOR(var_attrs, var_shift);
+    VectorIterator   beta_iter  = CONVERT_TO_VECTOR_ITERATOR(beta_attrs, beta_shift);
+    VectorIterator   gamma_iter = CONVERT_TO_VECTOR_ITERATOR(gamma_attrs, gamma_shift);
 
-    uvec2 packed_s[5];
     vec4  unpacked_s[5];
     float denominator;
     float numerator;
@@ -190,16 +131,11 @@
     vec4  result;
 
     uint current_slice = gl_GlobalInvocationID.z;
-    packed_s[0]        = src_ptr[src.current_offset >> 3];
-    packed_s[1]        = var_ptr[(var.current_offset + current_slice * var.stride_x) >> 3];
-    packed_s[2]        = mean_ptr[(mean.current_offset + current_slice * mean.stride_x) >> 3];
-    packed_s[3]        = gamma_ptr[(gamma.current_offset + current_slice * beta.stride_x) >> 3];
-    packed_s[4]        = beta_ptr[(beta.current_offset + current_slice * beta.stride_x) >> 3];
-    unpacked_s[0]      = vec4(unpackHalf2x16(packed_s[0].x), unpackHalf2x16(packed_s[0].y));
-    unpacked_s[1]      = vec4(unpackHalf2x16(packed_s[1].x), unpackHalf2x16(packed_s[1].y));
-    unpacked_s[2]      = vec4(unpackHalf2x16(packed_s[2].x), unpackHalf2x16(packed_s[2].y));
-    unpacked_s[3]      = vec4(unpackHalf2x16(packed_s[3].x), unpackHalf2x16(packed_s[3].y));
-    unpacked_s[4]      = vec4(unpackHalf2x16(packed_s[4].x), unpackHalf2x16(packed_s[4].y));
+    unpacked_s[0]      = LOAD_UNPACK4_CURRENT_ITEM_HALF(src_ptr, src_iter);
+    unpacked_s[1]      = LOAD_UNPACK4_HALF(var_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(var_iter, current_slice * var_attrs.stride_x));
+    unpacked_s[2]      = LOAD_UNPACK4_HALF(mean_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(mean_iter, current_slice * mean_attrs.stride_x));
+    unpacked_s[3]      = LOAD_UNPACK4_HALF(gamma_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(gamma_iter, current_slice * gamma_attrs.stride_x));
+    unpacked_s[4]      = LOAD_UNPACK4_HALF(beta_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(beta_iter, current_slice * beta_attrs.stride_x));
 
     if((current_slice % uint(4)) == uint(0))
     {
@@ -214,7 +150,7 @@
         beta_param  = unpacked_s[4].x;
         result      = ADD_OP(MUL_OP(gamma_param, x_bar), beta_param);
 
-        dst_ptr[dst.current_offset >> 3] = uvec2(packHalf2x16(result.xy), packHalf2x16(result.zw));
+        STORE_PACK4_CURRENT_ITEM_HALF(dst_ptr, dst_iter, result);
     }
     else if((current_slice % uint(4)) == uint(1))
     {
@@ -229,7 +165,7 @@
         beta_param  = unpacked_s[4].y;
         result      = ADD_OP(MUL_OP(gamma_param, x_bar), beta_param);
 
-        dst_ptr[dst.current_offset >> 3] = uvec2(packHalf2x16(result.xy), packHalf2x16(result.zw));
+        STORE_PACK4_CURRENT_ITEM_HALF(dst_ptr, dst_iter, result);
     }
     else if((current_slice % uint(4)) == uint(2))
     {
@@ -244,7 +180,7 @@
         beta_param  = unpacked_s[4].z;
         result      = ADD_OP(MUL_OP(gamma_param, x_bar), beta_param);
 
-        dst_ptr[dst.current_offset >> 3] = uvec2(packHalf2x16(result.xy), packHalf2x16(result.zw));
+        STORE_PACK4_CURRENT_ITEM_HALF(dst_ptr, dst_iter, result);
     }
     else
     {
@@ -259,7 +195,7 @@
         beta_param  = unpacked_s[4].w;
         result      = ADD_OP(MUL_OP(gamma_param, x_bar), beta_param);
 
-        dst_ptr[dst.current_offset >> 3] = uvec2(packHalf2x16(result.xy), packHalf2x16(result.zw));
+        STORE_PACK4_CURRENT_ITEM_HALF(dst_ptr, dst_iter, result);
     }
 }
 #endif /*DATA_TYPE_FP16*/