Improve Winograd performance on OpenCL

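- Compute more output elements per work-item (N0) for FP16 in the
  Winograd input/output transforms
- Cast floating-point constants to DATA_TYPE so that FP16 arithmetic
  is not implicitly promoted to FP32
- Make the batch work-item dimension optional (IS_BATCHED)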

Resolves COMPMID-6018

Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Change-Id: If5e6f5182eff8c1f05a3505c437d0a997490f0bd
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9447
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/nhwc/winograd_input_transform.cl b/src/core/CL/cl_kernels/nhwc/winograd_input_transform.cl
index ba7b13b..7341336 100644
--- a/src/core/CL/cl_kernels/nhwc/winograd_input_transform.cl
+++ b/src/core/CL/cl_kernels/nhwc/winograd_input_transform.cl
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,42 +24,42 @@
 #include "helpers.h"
 #include "tile_helpers.h"

-#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact)                     \
-    ({                                                              \
-        comm_fact.s0 = tmp.s2 - 4.25f * tmp.s4 + tmp.s6;            \
-        comm_fact.s1 = tmp.s1 - 4.25f * tmp.s3 + tmp.s5;            \
-        comm_fact.s2 = 2.5f * tmp.s3;                               \
-        comm_fact.s3 = 0.5f * tmp.s1 + 2.f * tmp.s5 - comm_fact.s2; \
-        comm_fact.s4 = 0.25f * tmp.s2 - 1.25f * tmp.s4 + tmp.s6;    \
-        comm_fact.s5 = 4.f * tmp.s2 + tmp.s6 - 5.f * tmp.s4;        \
-        comm_fact.s6 = 2.f * tmp.s1 + 0.5f * tmp.s5 - comm_fact.s2; \
+#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) /* out.s0-s7 from tmp.s0-s7, comm_fact is overwritten scratch; DATA_TYPE casts keep half inputs from being promoted to float */ \
+    ({                                                                                    \
+        comm_fact.s0 = tmp.s2 - (DATA_TYPE)4.25f * tmp.s4 + tmp.s6;                       \
+        comm_fact.s1 = tmp.s1 - (DATA_TYPE)4.25f * tmp.s3 + tmp.s5;                       \
+        comm_fact.s2 = (DATA_TYPE)2.5f * tmp.s3;                                          \
+        comm_fact.s3 = (DATA_TYPE)0.5f * tmp.s1 + (DATA_TYPE)2.f * tmp.s5 - comm_fact.s2; \
+        comm_fact.s4 = (DATA_TYPE)0.25f * tmp.s2 - (DATA_TYPE)1.25f * tmp.s4 + tmp.s6;    \
+        comm_fact.s5 = (DATA_TYPE)4.f * tmp.s2 + tmp.s6 - (DATA_TYPE)5.f * tmp.s4;        \
+        comm_fact.s6 = (DATA_TYPE)2.f * tmp.s1 + (DATA_TYPE)0.5f * tmp.s5 - comm_fact.s2; \
         \
-        out.s0 = tmp.s0 - tmp.s6 + 5.25f * tmp.s4 - 5.25f * tmp.s2; \
-        out.s1 = comm_fact.s0 + comm_fact.s1;                       \
-        out.s2 = comm_fact.s0 - comm_fact.s1;                       \
-        out.s3 = comm_fact.s3 + comm_fact.s4;                       \
-        out.s4 = comm_fact.s4 - comm_fact.s3;                       \
-        out.s5 = comm_fact.s5 + comm_fact.s6;                       \
-        out.s6 = comm_fact.s5 - comm_fact.s6;                       \
-        out.s7 = tmp.s7 - tmp.s1 + 5.25f * tmp.s3 - 5.25f * tmp.s5; \
+        out.s0 = tmp.s0 - tmp.s6 + (DATA_TYPE)5.25f * tmp.s4 - (DATA_TYPE)5.25f * tmp.s2; \
+        out.s1 = comm_fact.s0 + comm_fact.s1;                                             \
+        out.s2 = comm_fact.s0 - comm_fact.s1;                                             \
+        out.s3 = comm_fact.s3 + comm_fact.s4;                                             \
+        out.s4 = comm_fact.s4 - comm_fact.s3;                                             \
+        out.s5 = comm_fact.s5 + comm_fact.s6;                                             \
+        out.s6 = comm_fact.s5 - comm_fact.s6;                                             \
+        out.s7 = tmp.s7 - tmp.s1 + (DATA_TYPE)5.25f * tmp.s3 - (DATA_TYPE)5.25f * tmp.s5; \
     })

-#define OUTPUT_ROW_2x2_7x7(out, tmp, comm_fact)                                                    \
-    ({                                                                                             \
-        comm_fact.s0 = 36.0f * tmp.s2 - 13.0f * tmp.s4 + tmp.s6;                                   \
-        comm_fact.s1 = 36.0f * tmp.s1 - 13.0f * tmp.s3 + 1.0f * tmp.s5;                            \
-        comm_fact.s2 = 9.0f * tmp.s2 - 10.0f * tmp.s4 + tmp.s6;                                    \
-        comm_fact.s3 = 18.0f * tmp.s1 - 20.0f * tmp.s3 + 2.0f * tmp.s5;                            \
-        comm_fact.s4 = 4.0f * tmp.s2 - 5.0f * tmp.s4 + tmp.s6;                                     \
-        comm_fact.s5 = 12.0f * tmp.s1 - 15.0f * tmp.s3 + 3.0f * tmp.s5;                            \
-        out.s0       = -36.0f * tmp.s0 + 49.0f * tmp.s2 + -14.0f * tmp.s4 + tmp.s6;                \
-        out.s1       = comm_fact.s0 - comm_fact.s1;                                                \
-        out.s2       = comm_fact.s0 + comm_fact.s1;                                                \
-        out.s3       = comm_fact.s2 - comm_fact.s3;                                                \
-        out.s4       = comm_fact.s2 + comm_fact.s3;                                                \
-        out.s5       = comm_fact.s4 - comm_fact.s5;                                                \
-        out.s6       = comm_fact.s4 + comm_fact.s5;                                                \
-        out.s7       = -36.0f * tmp.s1 + 0.0f * tmp.s2 + 49.0f * tmp.s3 - 14.0f * tmp.s5 + tmp.s7; \
+#define OUTPUT_ROW_2x2_7x7(out, tmp, comm_fact) /* out.s0-s7 from tmp.s0-s7, comm_fact.s0-s5 is overwritten scratch; DATA_TYPE casts keep half inputs from being promoted to float */ \
+    ({                                                                                                                                         \
+        comm_fact.s0 = (DATA_TYPE)36.0f * tmp.s2 - (DATA_TYPE)13.0f * tmp.s4 + tmp.s6;                                                         \
+        comm_fact.s1 = (DATA_TYPE)36.0f * tmp.s1 - (DATA_TYPE)13.0f * tmp.s3 + (DATA_TYPE)1.0f * tmp.s5;                                       \
+        comm_fact.s2 = (DATA_TYPE)9.0f * tmp.s2 - (DATA_TYPE)10.0f * tmp.s4 + tmp.s6;                                                          \
+        comm_fact.s3 = (DATA_TYPE)18.0f * tmp.s1 - (DATA_TYPE)20.0f * tmp.s3 + (DATA_TYPE)2.0f * tmp.s5;                                       \
+        comm_fact.s4 = (DATA_TYPE)4.0f * tmp.s2 - (DATA_TYPE)5.0f * tmp.s4 + tmp.s6;                                                           \
+        comm_fact.s5 = (DATA_TYPE)12.0f * tmp.s1 - (DATA_TYPE)15.0f * tmp.s3 + (DATA_TYPE)3.0f * tmp.s5;                                       \
+        out.s0       = -(DATA_TYPE)36.0f * tmp.s0 + (DATA_TYPE)49.0f * tmp.s2 + -(DATA_TYPE)14.0f * tmp.s4 + tmp.s6;                           \
+        out.s1       = comm_fact.s0 - comm_fact.s1;                                                                                            \
+        out.s2       = comm_fact.s0 + comm_fact.s1;                                                                                            \
+        out.s3       = comm_fact.s2 - comm_fact.s3;                                                                                            \
+        out.s4       = comm_fact.s2 + comm_fact.s3;                                                                                            \
+        out.s5       = comm_fact.s4 - comm_fact.s5;                                                                                            \
+        out.s6       = comm_fact.s4 + comm_fact.s5;                                                                                            \
+        out.s7       = -(DATA_TYPE)36.0f * tmp.s1 + (DATA_TYPE)0.0f * tmp.s2 + (DATA_TYPE)49.0f * tmp.s3 - (DATA_TYPE)14.0f * tmp.s5 + tmp.s7; \
     })

 #if defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
@@ -113,9 +113,13 @@
     const int _INUM_TILES_X,
     const int _INUM_TILES_Y)
 {
-    const int cout = GET_SPATIAL_IDX(0, 1, 0); // OFM
-    const int mout = GET_SPATIAL_IDX(1, 1, 0); // NUM_TILES_X x NUM_TILES_Y
+    const int cout = GET_SPATIAL_IDX(0, N0, 0); // OFM: presumably the first of N0 channels for this work-item (partial step 0, so no leftover handling; channels % N0 == 0 assumed, TODO confirm on host side)
+    const int mout = GET_SPATIAL_IDX(1, 1, 0);  // NUM_TILES_X x NUM_TILES_Y
+#if defined(IS_BATCHED)
     const int bout = GET_SPATIAL_IDX(2, 1, 0); // BATCH SIZE IDX
+#else                                          // defined(IS_BATCHED)
+    const int bout = 0; // BATCH SIZE IDX (always 0 when IS_BATCHED is not defined)
+#endif                                         // defined(IS_BATCHED)

     int x = (mout % _INUM_TILES_X) * OUTPUT_TILE_W;
     int y = (mout / _INUM_TILES_X) * OUTPUT_TILE_H;
@@ -124,8 +128,8 @@

 #if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

-    TILE(DATA_TYPE, 6, 1, in);
-    TILE(DATA_TYPE, 6, 1, out);
+    TILE(DATA_TYPE, 6, N0, in);
+    TILE(DATA_TYPE, 6, N0, out);

     // Initialize the input tile
     LOOP_UNROLLING(int, i, 0, 1, 6,
@@ -134,22 +138,22 @@
     })

 #if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
-    T_LOAD_NHWC(DATA_TYPE, 1, 6, 1, BUFFER, src, bout, y, x, cout, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, in);
+    T_LOAD_NHWC(DATA_TYPE, 1, 6, N0, BUFFER, src, bout, y, x, cout, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, in);
 #else  // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
-    T_LOAD_NHWC(DATA_TYPE, 6, 1, 1, BUFFER, src, bout, y, x, cout, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, in);
+    T_LOAD_NHWC(DATA_TYPE, 6, 1, N0, BUFFER, src, bout, y, x, cout, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, in);
 #endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)

-    TILE(DATA_TYPE, 6, 1, com);
+    TILE(DATA_TYPE, 6, N0, com);

     LOOP_UNROLLING(int, i, 0, 1, 6,
     {
-        in[i].v *= 4.0f;
+        in[i].v *= (DATA_TYPE)4.0f;
     })

-    com[0].v = in[2].v - 4.f * in[0].v;
-    com[1].v = in[3].v - 4.f * in[1].v;
-    com[2].v = in[4].v - 4.f * in[2].v;
-    com[3].v = in[5].v - 4.f * in[3].v;
+    com[0].v = in[2].v - (DATA_TYPE)4.f * in[0].v;
+    com[1].v = in[3].v - (DATA_TYPE)4.f * in[1].v;
+    com[2].v = in[4].v - (DATA_TYPE)4.f * in[2].v;
+    com[3].v = in[5].v - (DATA_TYPE)4.f * in[3].v;
     com[4].v = in[3].v - in[1].v;
     com[4].v = com[4].v + com[4].v;
     com[5].v = in[4].v - in[2].v;
@@ -169,11 +173,11 @@
         dst_indirect_y[i].v += bout *_INUM_TILES_X *_INUM_TILES_Y * 6;
     })

-    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, 6, 1, 0, BUFFER, dst, cout, dst_stride_y, false, out, dst_indirect_y);
+    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, 6, N0, 0, BUFFER, dst, cout, dst_stride_y, false, out, dst_indirect_y); // NOTE(review): width condition is false, so presumably all N0 channels are always stored
 
 #else  // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

-    TILE(DATA_TYPE, 36, 1, in);
+    TILE(DATA_TYPE, 36, N0, in);

     // Initialize the input tile
     LOOP_UNROLLING(int, i, 0, 1, 36,
@@ -182,10 +186,10 @@
     })

     // Load the tile from a NHWC tensor
-    T_LOAD_NHWC(DATA_TYPE, 6, 6, 1, BUFFER, src, bout, y, x, cout, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, in);
+    T_LOAD_NHWC(DATA_TYPE, 6, 6, N0, BUFFER, src, bout, y, x, cout, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, in);

-    TILE(DATA_TYPE, 6, 1, com);
-    TILE(DATA_TYPE, 36, 1, tmp);
+    TILE(DATA_TYPE, 6, N0, com);
+    TILE(DATA_TYPE, 36, N0, tmp);

     LOOP_UNROLLING(int, i, 0, 1, 6,
     {
@@ -204,14 +208,14 @@
         tmp[i + 5 * 6].v = com[3].v - com[1].v;
     })

-    TILE(DATA_TYPE, 36, 1, out);
+    TILE(DATA_TYPE, 36, N0, out);

     LOOP_UNROLLING(int, i, 0, 1, 6,
     {
-        com[0].v         = tmp[i * 6 + 2].v - 4.f *tmp[i * 6 + 0].v;
-        com[1].v         = tmp[i * 6 + 3].v - 4.f *tmp[i * 6 + 1].v;
-        com[2].v         = tmp[i * 6 + 4].v - 4.f *tmp[i * 6 + 2].v;
-        com[3].v         = tmp[i * 6 + 5].v - 4.f *tmp[i * 6 + 3].v;
+        com[0].v         = tmp[i * 6 + 2].v - (DATA_TYPE)4.f *tmp[i * 6 + 0].v;
+        com[1].v         = tmp[i * 6 + 3].v - (DATA_TYPE)4.f *tmp[i * 6 + 1].v;
+        com[2].v         = tmp[i * 6 + 4].v - (DATA_TYPE)4.f *tmp[i * 6 + 2].v;
+        com[3].v         = tmp[i * 6 + 5].v - (DATA_TYPE)4.f *tmp[i * 6 + 3].v;
         com[4].v         = tmp[i * 6 + 3].v - tmp[i * 6 + 1].v;
         com[4].v         = com[4].v + com[4].v;
         com[5].v         = tmp[i * 6 + 4].v - tmp[i * 6 + 2].v;
@@ -232,7 +236,7 @@
         dst_indirect_y[i].v += bout *_INUM_TILES_X *_INUM_TILES_Y * 36;
     })

-    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, 36, 1, 0, BUFFER, dst, cout, dst_stride_y, false, out, dst_indirect_y);
+    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, 36, N0, 0, BUFFER, dst, cout, dst_stride_y, false, out, dst_indirect_y); // NOTE(review): width condition is false, so presumably all N0 channels are always stored
 #endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
 }
 #endif // defined(WINOGRAD_INPUT_TRANSFORM_4X4_3X3_STEPZ1_NHWC) || defined(WINOGRAD_INPUT_TRANSFORM_4X1_3X1_STEPZ1_NHWC) || defined(WINOGRAD_INPUT_TRANSFORM_1X4_1X3_STEPZ1_NHWC)
@@ -287,7 +291,11 @@
 {
     const int cout = GET_SPATIAL_IDX(0, 1, 0); // OFM
     const int mout = GET_SPATIAL_IDX(1, 1, 0); // NUM_TILES_X x NUM_TILES_Y
+#if defined(IS_BATCHED)
     const int bout = GET_SPATIAL_IDX(2, 1, 0); // BATCH SIZE IDX
+#else                                          // defined(IS_BATCHED)
+    const int bout = 0; // BATCH SIZE IDX (always 0 when IS_BATCHED is not defined)
+#endif                                         // defined(IS_BATCHED)

     int x = (mout % _INUM_TILES_X) * OUTPUT_TILE_W;
     int y = (mout / _INUM_TILES_X) * OUTPUT_TILE_H;
@@ -306,27 +314,27 @@
     })

 #if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
-    T_LOAD_NHWC(DATA_TYPE, 1, 8, 1, BUFFER, src, bout, y, x, cout, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, in);
+    T_LOAD_NHWC(DATA_TYPE, 1, 8, N0, BUFFER, src, bout, y, x, cout, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, in);
 #else  // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
-    T_LOAD_NHWC(DATA_TYPE, 8, 1, 1, BUFFER, src, bout, y, x, cout, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, in);
+    T_LOAD_NHWC(DATA_TYPE, 8, 1, N0, BUFFER, src, bout, y, x, cout, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, in);
 #endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)

     TILE(DATA_TYPE, 1, 8, com);

-    com[0].s[0] = in[2].v - 4.25f * in[4].v + in[6].v;
-    com[0].s[1] = in[1].v - 4.25f * in[3].v + in[5].v;
-    com[0].s[2] = 0.5f * in[1].v - 2.5f * in[3].v + 2.0f * in[5].v;
-    com[0].s[3] = 0.25f * in[2].v - 1.25f * in[4].v + in[6].v;
-    com[0].s[4] = 4.0f * in[2].v - 5.0f * in[4].v + in[6].v;
-    com[0].s[5] = 2.0f * in[1].v - 2.5f * in[3].v + 0.5f * in[5].v;
-    out[0].s[0] = in[0].v - 5.25f * in[2].v + 5.25f * in[4].v - in[6].v;
+    com[0].s[0] = in[2].v - (DATA_TYPE)4.25f * in[4].v + in[6].v; // NOTE(review): scalar .s[] lanes here imply N0 == 1 for this kernel - confirm on host side
+    com[0].s[1] = in[1].v - (DATA_TYPE)4.25f * in[3].v + in[5].v;
+    com[0].s[2] = (DATA_TYPE)0.5f * in[1].v - (DATA_TYPE)2.5f * in[3].v + (DATA_TYPE)2.0f * in[5].v;
+    com[0].s[3] = (DATA_TYPE)0.25f * in[2].v - (DATA_TYPE)1.25f * in[4].v + in[6].v;
+    com[0].s[4] = (DATA_TYPE)4.0f * in[2].v - (DATA_TYPE)5.0f * in[4].v + in[6].v;
+    com[0].s[5] = (DATA_TYPE)2.0f * in[1].v - (DATA_TYPE)2.5f * in[3].v + (DATA_TYPE)0.5f * in[5].v;
+    out[0].s[0] = in[0].v - (DATA_TYPE)5.25f * in[2].v + (DATA_TYPE)5.25f * in[4].v - in[6].v; // every constant cast: a bare float literal would promote the whole FP16 expression to FP32
     out[1].s[0] = com[0].s[0] + com[0].s[1];
     out[2].s[0] = com[0].s[0] - com[0].s[1];
     out[3].s[0] = com[0].s[3] + com[0].s[2];
     out[4].s[0] = com[0].s[3] - com[0].s[2];
     out[5].s[0] = com[0].s[4] + com[0].s[5];
     out[6].s[0] = com[0].s[4] - com[0].s[5];
-    out[7].s[0] = -in[1].v + 5.25f * in[3].v - 5.25f * in[5].v + in[7].v;
+    out[7].s[0] = -in[1].v + (DATA_TYPE)5.25f * in[3].v - (DATA_TYPE)5.25f * in[5].v + in[7].v;

     TILE(uint, 8, 1, dst_indirect_y);

@@ -378,20 +386,20 @@

     LOOP_UNROLLING(int, i, 0, 1, 8,
     {
-        com[0].s[0]         = tmp[i].s[2] - 4.25f * tmp[i].s[4] + tmp[i].s[6];
-        com[0].s[1]         = tmp[i].s[1] - 4.25f * tmp[i].s[3] + tmp[i].s[5];
-        com[0].s[2]         = 0.5f * tmp[i].s[1] - 2.5f * tmp[i].s[3] + 2.0f * tmp[i].s[5];
-        com[0].s[3]         = 0.25f * tmp[i].s[2] - 1.25f * tmp[i].s[4] + tmp[i].s[6];
-        com[0].s[4]         = 4.0f * tmp[i].s[2] - 5.0f * tmp[i].s[4] + tmp[i].s[6];
-        com[0].s[5]         = 2.0f * tmp[i].s[1] - 2.5f * tmp[i].s[3] + 0.5f * tmp[i].s[5];
-        out[i * 8 + 0].s[0] = tmp[i].s[0] - 5.25f * tmp[i].s[2] + 5.25f * tmp[i].s[4] - tmp[i].s[6];
+        com[0].s[0]         = tmp[i].s[2] - (DATA_TYPE)4.25f * tmp[i].s[4] + tmp[i].s[6];
+        com[0].s[1]         = tmp[i].s[1] - (DATA_TYPE)4.25f * tmp[i].s[3] + tmp[i].s[5];
+        com[0].s[2]         = (DATA_TYPE)0.5f * tmp[i].s[1] - (DATA_TYPE)2.5f * tmp[i].s[3] + (DATA_TYPE)2.0f * tmp[i].s[5];
+        com[0].s[3]         = (DATA_TYPE)0.25f * tmp[i].s[2] - (DATA_TYPE)1.25f * tmp[i].s[4] + tmp[i].s[6];
+        com[0].s[4]         = (DATA_TYPE)4.0f * tmp[i].s[2] - (DATA_TYPE)5.0f * tmp[i].s[4] + tmp[i].s[6];
+        com[0].s[5]         = (DATA_TYPE)2.0f * tmp[i].s[1] - (DATA_TYPE)2.5f * tmp[i].s[3] + (DATA_TYPE)0.5f * tmp[i].s[5];
+        out[i * 8 + 0].s[0] = tmp[i].s[0] - (DATA_TYPE)5.25f * tmp[i].s[2] + (DATA_TYPE)5.25f * tmp[i].s[4] - tmp[i].s[6];
         out[i * 8 + 1].s[0] = com[0].s[0] + com[0].s[1];
         out[i * 8 + 2].s[0] = com[0].s[0] - com[0].s[1];
         out[i * 8 + 3].s[0] = com[0].s[3] + com[0].s[2];
         out[i * 8 + 4].s[0] = com[0].s[3] - com[0].s[2];
         out[i * 8 + 5].s[0] = com[0].s[4] + com[0].s[5];
         out[i * 8 + 6].s[0] = com[0].s[4] - com[0].s[5];
-        out[i * 8 + 7].s[0] = -tmp[i].s[1] + 5.25f * tmp[i].s[3] - 5.25f * tmp[i].s[5] + tmp[i].s[7];
+        out[i * 8 + 7].s[0] = -tmp[i].s[1] + (DATA_TYPE)5.25f * tmp[i].s[3] - (DATA_TYPE)5.25f * tmp[i].s[5] + tmp[i].s[7];
     })

     TILE(uint, 64, 1, dst_indirect_y);
@@ -458,7 +466,11 @@
 {
     const int cout = GET_SPATIAL_IDX(0, 1, 0); // OFM
     const int mout = GET_SPATIAL_IDX(1, 1, 0); // NUM_TILES_X x NUM_TILES_Y
+#if defined(IS_BATCHED)
     const int bout = GET_SPATIAL_IDX(2, 1, 0); // BATCH SIZE IDX
+#else                                          // defined(IS_BATCHED)
+    const int bout = 0; // BATCH SIZE IDX (always 0 when IS_BATCHED is not defined)
+#endif                                         // defined(IS_BATCHED)

     int x = (mout % _INUM_TILES_X) * OUTPUT_TILE_W;
     int y = (mout / _INUM_TILES_X) * OUTPUT_TILE_H;
@@ -489,20 +501,20 @@

     TILE(DATA_TYPE, 1, 8, com) = { { { 0 } } };

-    com[0].s[0] = 36.0f * in[2].v - 13.0f * in[4].v + in[6].v;
-    com[0].s[1] = 36.0f * in[1].v - 13.0f * in[3].v + 1.0f * in[5].v;
-    com[0].s[2] = 9.0f * in[2].v - 10.0f * in[4].v + in[6].v;
-    com[0].s[3] = 18.0f * in[1].v - 20.0f * in[3].v + 2.0f * in[5].v;
-    com[0].s[4] = 4.0f * in[2].v - 5.0f * in[4].v + in[6].v;
-    com[0].s[5] = 12.0f * in[1].v - 15.0f * in[3].v + 3.0f * in[5].v;
-    out[0].s[0] = -36.0f * in[0].v + 49.0f * in[2].v + -14.0f * in[4].v + in[6].v;
+    com[0].s[0] = (DATA_TYPE)36.0f * in[2].v - (DATA_TYPE)13.0f * in[4].v + in[6].v;
+    com[0].s[1] = (DATA_TYPE)36.0f * in[1].v - (DATA_TYPE)13.0f * in[3].v + (DATA_TYPE)1.0f * in[5].v;
+    com[0].s[2] = (DATA_TYPE)9.0f * in[2].v - (DATA_TYPE)10.0f * in[4].v + in[6].v;
+    com[0].s[3] = (DATA_TYPE)18.0f * in[1].v - (DATA_TYPE)20.0f * in[3].v + (DATA_TYPE)2.0f * in[5].v;
+    com[0].s[4] = (DATA_TYPE)4.0f * in[2].v - (DATA_TYPE)5.0f * in[4].v + in[6].v;
+    com[0].s[5] = (DATA_TYPE)12.0f * in[1].v - (DATA_TYPE)15.0f * in[3].v + (DATA_TYPE)3.0f * in[5].v;
+    out[0].s[0] = -(DATA_TYPE)36.0f * in[0].v + (DATA_TYPE)49.0f * in[2].v + -(DATA_TYPE)14.0f * in[4].v + in[6].v;
     out[1].s[0] = com[0].s[0] - com[0].s[1];
     out[2].s[0] = com[0].s[0] + com[0].s[1];
     out[3].s[0] = com[0].s[2] - com[0].s[3];
     out[4].s[0] = com[0].s[2] + com[0].s[3];
     out[5].s[0] = com[0].s[4] - com[0].s[5];
     out[6].s[0] = com[0].s[4] + com[0].s[5];
-    out[7].s[0] = -36.0f * in[1].v + 0.0f * in[2].v + 49.0f * in[3].v - 14.0f * in[5].v + in[7].v;
+    out[7].s[0] = -(DATA_TYPE)36.0f * in[1].v + (DATA_TYPE)0.0f * in[2].v + (DATA_TYPE)49.0f * in[3].v - (DATA_TYPE)14.0f * in[5].v + in[7].v;

     TILE(uint, 8, 1, dst_indirect_y);

@@ -554,20 +566,20 @@

     LOOP_UNROLLING(int, i, 0, 1, 8,
     {
-        com[0].s[0]         = 36.0f * tmp[i].s[2] - 13.0f * tmp[i].s[4] + tmp[i].s[6];
-        com[0].s[1]         = 36.0f * tmp[i].s[1] - 13.0f * tmp[i].s[3] + 1.0f * tmp[i].s[5];
-        com[0].s[2]         = 9.0f * tmp[i].s[2] - 10.0f * tmp[i].s[4] + tmp[i].s[6];
-        com[0].s[3]         = 18.0f * tmp[i].s[1] - 20.0f * tmp[i].s[3] + 2.0f * tmp[i].s[5];
-        com[0].s[4]         = 4.0f * tmp[i].s[2] - 5.0f * tmp[i].s[4] + tmp[i].s[6];
-        com[0].s[5]         = 12.0f * tmp[i].s[1] - 15.0f * tmp[i].s[3] + 3.0f * tmp[i].s[5];
-        out[i * 8 + 0].s[0] = -36.0f * tmp[i].s[0] + 49.0f * tmp[i].s[2] + -14.0f * tmp[i].s[4] + tmp[i].s[6];
+        com[0].s[0]         = (DATA_TYPE)36.0f * tmp[i].s[2] - (DATA_TYPE)13.0f * tmp[i].s[4] + tmp[i].s[6];
+        com[0].s[1]         = (DATA_TYPE)36.0f * tmp[i].s[1] - (DATA_TYPE)13.0f * tmp[i].s[3] + (DATA_TYPE)1.0f * tmp[i].s[5];
+        com[0].s[2]         = (DATA_TYPE)9.0f * tmp[i].s[2] - (DATA_TYPE)10.0f * tmp[i].s[4] + tmp[i].s[6];
+        com[0].s[3]         = (DATA_TYPE)18.0f * tmp[i].s[1] - (DATA_TYPE)20.0f * tmp[i].s[3] + (DATA_TYPE)2.0f * tmp[i].s[5];
+        com[0].s[4]         = (DATA_TYPE)4.0f * tmp[i].s[2] - (DATA_TYPE)5.0f * tmp[i].s[4] + tmp[i].s[6];
+        com[0].s[5]         = (DATA_TYPE)12.0f * tmp[i].s[1] - (DATA_TYPE)15.0f * tmp[i].s[3] + (DATA_TYPE)3.0f * tmp[i].s[5];
+        out[i * 8 + 0].s[0] = -(DATA_TYPE)36.0f * tmp[i].s[0] + (DATA_TYPE)49.0f * tmp[i].s[2] + -(DATA_TYPE)14.0f * tmp[i].s[4] + tmp[i].s[6];
         out[i * 8 + 1].s[0] = com[0].s[0] - com[0].s[1];
         out[i * 8 + 2].s[0] = com[0].s[0] + com[0].s[1];
         out[i * 8 + 3].s[0] = com[0].s[2] - com[0].s[3];
         out[i * 8 + 4].s[0] = com[0].s[2] + com[0].s[3];
         out[i * 8 + 5].s[0] = com[0].s[4] - com[0].s[5];
         out[i * 8 + 6].s[0] = com[0].s[4] + com[0].s[5];
-        out[i * 8 + 7].s[0] = -36.0f * tmp[i].s[1] + 0.0f * tmp[i].s[2] + 49.0f * tmp[i].s[3] - 14.0f * tmp[i].s[5] + tmp[i].s[7];
+        out[i * 8 + 7].s[0] = -(DATA_TYPE)36.0f * tmp[i].s[1] + (DATA_TYPE)0.0f * tmp[i].s[2] + (DATA_TYPE)49.0f * tmp[i].s[3] - (DATA_TYPE)14.0f * tmp[i].s[5] + tmp[i].s[7];
     })

     TILE(uint, 64, 1, dst_indirect_y);