COMPMID-1740: Fuse batch normalization with Convolution Layer at graph level
Change-Id: I77ca51c2c72783cc26a099a6a9c3210cdbbe822d
Signed-off-by: giuros01 <giuseppe.rossini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/797
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/src/core/CL/cl_kernels/batchnormalization_layer.cl b/src/core/CL/cl_kernels/batchnormalization_layer.cl
index dfd16e0..60307bc 100644
--- a/src/core/CL/cl_kernels/batchnormalization_layer.cl
+++ b/src/core/CL/cl_kernels/batchnormalization_layer.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -341,22 +341,10 @@
Vector bn_mean = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bn_mean);
Vector bn_var = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bn_var);
- // In-place ops
-#ifdef IN_PLACE_W
- Tensor4D fused_w = conv_w;
-#else /* IN_PLACE_W */
- Tensor4D fused_w = CONVERT_TO_TENSOR4D_STRUCT(fused_w, NUM_CHANNELS);
-#endif /* IN_PLACE */
-#ifdef IN_PLACE_B
- Vector fused_b = conv_b;
-#else /* IN_PLACE_W */
- Vector fused_b = CONVERT_TO_VECTOR_STRUCT_NO_STEP(fused_b);
-#endif /* IN_PLACE */
-
// Conditional ops
#ifdef HAS_BIAS
Vector conv_b = CONVERT_TO_VECTOR_STRUCT_NO_STEP(conv_b);
-#endif /* USE_DEFAULT_BETA */
+#endif /* HAS_BIAS */
#ifndef USE_DEFAULT_BETA
Vector bn_beta = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bn_beta);
#endif /* USE_DEFAULT_BETA */
@@ -364,6 +352,19 @@
Vector bn_gamma = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bn_gamma);
#endif /* USE_DEFAULT_GAMMA */
+ // In-place ops
+#ifdef IN_PLACE_W
+ Tensor4D fused_w = conv_w;
+ uint fused_w_stride_x = conv_w_stride_x;
+#else /* IN_PLACE_W */
+ Tensor4D fused_w = CONVERT_TO_TENSOR4D_STRUCT(fused_w, NUM_CHANNELS);
+#endif /* IN_PLACE_W */
+#ifdef IN_PLACE_B
+ Vector fused_b = conv_b;
+#else /* IN_PLACE_B */
+ Vector fused_b = CONVERT_TO_VECTOR_STRUCT_NO_STEP(fused_b);
+#endif /* IN_PLACE_B */
+
const int current_slice = get_global_id(2) / NUM_CHANNELS;
#if defined(VEC_SIZE) && defined(LAST_ACCESSED_X)