COMPMID-1009 Support 4x4 output tile for Winograd Filter Transform on OpenCL.

Change-Id: I68c6453e0f192de659582404f109a89616b9fbb9
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/124811
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
diff --git a/tests/validation/reference/Winograd.cpp b/tests/validation/reference/Winograd.cpp
index c760663..ad0dcbd 100644
--- a/tests/validation/reference/Winograd.cpp
+++ b/tests/validation/reference/Winograd.cpp
@@ -39,40 +39,74 @@
 namespace
 {
 template <typename T>
-void winograd_filter_transform3x3(const SimpleTensor<T> &in, SimpleTensor<T> &out)
+void winograd_filter_transform3x3(const SimpleTensor<T> &in, SimpleTensor<T> &out, const Size2D &output_tile)
 {
+    const bool         is_2x2      = (output_tile.width == 2);
+    const unsigned int transf_side = is_2x2 ? 4u : 6u;
+
     // Simple tensor for the 3x3 input tile
     SimpleTensor<T> input_tile{ TensorShape(3u, 3u), in.data_type(), 1 };
 
     // Simple tensor for the transformation matrix
-    SimpleTensor<T> trans_matrix{ TensorShape(3u, 4u), in.data_type(), 1 };
+    SimpleTensor<T> trans_matrix{ TensorShape(3u, transf_side), in.data_type(), 1 };
 
     // Simple tensor for the transformation matrix transpose
-    SimpleTensor<T> trans_matrix_transposed{ TensorShape(4u, 3u), in.data_type(), 1 };
+    SimpleTensor<T> trans_matrix_transposed{ TensorShape(transf_side, 3u), in.data_type(), 1 };
 
-    // Simple tensor for the 4x3 temporary tile
-    SimpleTensor<T> tmp_tile{ TensorShape(3u, 4u), in.data_type(), 1 };
+    // Simple tensor for the 3xSide temporary tile
+    SimpleTensor<T> tmp_tile{ TensorShape(3u, transf_side), in.data_type(), 1 };
 
-    // Simple tensor for the 4x4 output tile
-    SimpleTensor<T> output_tile{ TensorShape(4u, 4u), in.data_type(), 1 };
+    // Simple tensor for the SidexSide output tile
+    SimpleTensor<T> transf_tile{ TensorShape(transf_side, transf_side), in.data_type(), 1 };
 
-    // Initialize transformation matrix
-    // 1   | 0   | 0
-    // 0.5 | 0.5 | 0.5
-    // 0.5 |-0.5 | 0.5
-    // 0   | 0   | 1
-    trans_matrix[0 + 0 * 3] = 1.0f;
-    trans_matrix[1 + 0 * 3] = 0.0f;
-    trans_matrix[2 + 0 * 3] = 0.0f;
-    trans_matrix[0 + 1 * 3] = 0.5f;
-    trans_matrix[1 + 1 * 3] = 0.5f;
-    trans_matrix[2 + 1 * 3] = 0.5f;
-    trans_matrix[0 + 2 * 3] = 0.5f;
-    trans_matrix[1 + 2 * 3] = -0.5f;
-    trans_matrix[2 + 2 * 3] = 0.5f;
-    trans_matrix[0 + 3 * 3] = 0.0f;
-    trans_matrix[1 + 3 * 3] = 0.0f;
-    trans_matrix[2 + 3 * 3] = 1.0f;
+    if(is_2x2)
+    {
+        // Initialize 3x4 transformation matrix
+        // 1   | 0   | 0
+        // 0.5 | 0.5 | 0.5
+        // 0.5 |-0.5 | 0.5
+        // 0   | 0   | 1
+        trans_matrix[0 + 0 * 3] = 1.0f;
+        trans_matrix[1 + 0 * 3] = 0.0f;
+        trans_matrix[2 + 0 * 3] = 0.0f;
+        trans_matrix[0 + 1 * 3] = 0.5f;
+        trans_matrix[1 + 1 * 3] = 0.5f;
+        trans_matrix[2 + 1 * 3] = 0.5f;
+        trans_matrix[0 + 2 * 3] = 0.5f;
+        trans_matrix[1 + 2 * 3] = -0.5f;
+        trans_matrix[2 + 2 * 3] = 0.5f;
+        trans_matrix[0 + 3 * 3] = 0.0f;
+        trans_matrix[1 + 3 * 3] = 0.0f;
+        trans_matrix[2 + 3 * 3] = 1.0f;
+    }
+    else
+    {
+        // Initialize 3x6 transformation matrix
+        //   1/4  |    0   |   0
+        //  -1/6  |  -1/6  | -1/6
+        //  -1/6  |   1/6  | -1/6
+        //  1/24  |  1/12  |  1/6
+        //  1/24  | -1/12  |  1/6
+        //    0   |    0   |   1
+        trans_matrix[0 + 0 * 3] = 1.0f / 4.0f;
+        trans_matrix[1 + 0 * 3] = 0.0f;
+        trans_matrix[2 + 0 * 3] = 0.0f;
+        trans_matrix[0 + 1 * 3] = -1.0f / 6.0f;
+        trans_matrix[1 + 1 * 3] = -1.0f / 6.0f;
+        trans_matrix[2 + 1 * 3] = -1.0f / 6.0f;
+        trans_matrix[0 + 2 * 3] = -1.0f / 6.0f;
+        trans_matrix[1 + 2 * 3] = 1.0f / 6.0f;
+        trans_matrix[2 + 2 * 3] = -1.0f / 6.0f;
+        trans_matrix[0 + 3 * 3] = 1.0f / 24.0f;
+        trans_matrix[1 + 3 * 3] = 1.0f / 12.0f;
+        trans_matrix[2 + 3 * 3] = 1.0f / 6.0f;
+        trans_matrix[0 + 4 * 3] = 1.0f / 24.0f;
+        trans_matrix[1 + 4 * 3] = -1.0f / 12.0f;
+        trans_matrix[2 + 4 * 3] = 1.0f / 6.0f;
+        trans_matrix[0 + 5 * 3] = 0.0f;
+        trans_matrix[1 + 5 * 3] = 0.0f;
+        trans_matrix[2 + 5 * 3] = 1.0f;
+    }
 
     // Transpose the transformation matrix
     transpose_matrix(trans_matrix, trans_matrix_transposed);
@@ -94,26 +128,18 @@
                 matrix_multiply(trans_matrix, input_tile, tmp_tile);
 
                 // Second transformation
-                matrix_multiply(tmp_tile, trans_matrix_transposed, output_tile);
+                matrix_multiply(tmp_tile, trans_matrix_transposed, transf_tile);
 
                 // Store the 4x4 output tile across the 16 channels
-                const int output_offset                              = w + z * num_filters;
-                out[output_offset + 0 * num_filters * num_channels]  = output_tile[0 + 0 * 4];
-                out[output_offset + 1 * num_filters * num_channels]  = output_tile[1 + 0 * 4];
-                out[output_offset + 2 * num_filters * num_channels]  = output_tile[2 + 0 * 4];
-                out[output_offset + 3 * num_filters * num_channels]  = output_tile[3 + 0 * 4];
-                out[output_offset + 4 * num_filters * num_channels]  = output_tile[0 + 1 * 4];
-                out[output_offset + 5 * num_filters * num_channels]  = output_tile[1 + 1 * 4];
-                out[output_offset + 6 * num_filters * num_channels]  = output_tile[2 + 1 * 4];
-                out[output_offset + 7 * num_filters * num_channels]  = output_tile[3 + 1 * 4];
-                out[output_offset + 8 * num_filters * num_channels]  = output_tile[0 + 2 * 4];
-                out[output_offset + 9 * num_filters * num_channels]  = output_tile[1 + 2 * 4];
-                out[output_offset + 10 * num_filters * num_channels] = output_tile[2 + 2 * 4];
-                out[output_offset + 11 * num_filters * num_channels] = output_tile[3 + 2 * 4];
-                out[output_offset + 12 * num_filters * num_channels] = output_tile[0 + 3 * 4];
-                out[output_offset + 13 * num_filters * num_channels] = output_tile[1 + 3 * 4];
-                out[output_offset + 14 * num_filters * num_channels] = output_tile[2 + 3 * 4];
-                out[output_offset + 15 * num_filters * num_channels] = output_tile[3 + 3 * 4];
+                const int output_offset = w + z * num_filters;
+
+                for(unsigned int out_h = 0, out_pos = 0; out_h < transf_side; ++out_h)
+                {
+                    for(unsigned int out_w = 0; out_w < transf_side; ++out_w, ++out_pos)
+                    {
+                        out[output_offset + out_pos * num_filters * num_channels] = transf_tile[out_w + out_h * transf_side];
+                    }
+                }
             }
         }
     }
@@ -314,7 +340,7 @@
 }
 
 template <typename T>
-SimpleTensor<T> winograd_filter_transform(const SimpleTensor<T> &in, const TensorShape &output_shape)
+SimpleTensor<T> winograd_filter_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const Size2D &output_tile)
 {
     ARM_COMPUTE_ERROR_ON_MSG(in.data_layout() != DataLayout::NCHW, "Only supported NCHW data format");
 
@@ -324,7 +350,7 @@
     switch(in.shape()[0])
     {
         case 3:
-            winograd_filter_transform3x3(in, out);
+            winograd_filter_transform3x3(in, out, output_tile);
             break;
         default:
             ARM_COMPUTE_ERROR("Only supported 3x3 kernel");
@@ -358,7 +384,7 @@
 }
 
 template SimpleTensor<float> winograd_input_transform(const SimpleTensor<float> &src, const TensorShape &dst_shape, const PadStrideInfo &conv_info, const Size2D &kernel_dims);
-template SimpleTensor<float> winograd_filter_transform(const SimpleTensor<float> &in, const TensorShape &output_shape);
+template SimpleTensor<float> winograd_filter_transform(const SimpleTensor<float> &in, const TensorShape &output_shape, const Size2D &output_tile);
 template SimpleTensor<float> winograd_output_transform(const SimpleTensor<float> &in, const TensorShape &output_shape, const Size2D &kernel_dims, const Size2D &num_tiles);
 } // namespace reference
 } // namespace validation