COMPMID-1096 - Add fast_math flag to CLConvolutionLayer
COMPMID-1103 - CLWinogradConvolutionLayer mismatches

Change-Id: Iceaa9482a1790ec39d2720c220261aaea8043978
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129398
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/tests/datasets/LargeConvolutionLayerDataset.h b/tests/datasets/LargeConvolutionLayerDataset.h
index ec8e09f..36b3d60 100644
--- a/tests/datasets/LargeConvolutionLayerDataset.h
+++ b/tests/datasets/LargeConvolutionLayerDataset.h
@@ -59,6 +59,25 @@
     }
 };
 
+class LargeWinogradConvolutionLayer5x5Dataset final : public ConvolutionLayerDataset
+{
+public:
+    LargeWinogradConvolutionLayer5x5Dataset()
+    {
+        // Kernel size 5
+        // Batch size 1
+        add_config(TensorShape(224U, 224U, 3U), TensorShape(5U, 5U, 3U, 64U), TensorShape(64U), TensorShape(222U, 222U, 64U), PadStrideInfo(1, 1, 1, 1));
+        add_config(TensorShape(123U, 134U, 16U), TensorShape(5U, 5U, 16U, 7U), TensorShape(7U), TensorShape(121U, 130U, 7U), PadStrideInfo(1, 1, 1, 0));
+        add_config(TensorShape(181U, 152U, 42U), TensorShape(5U, 5U, 42U, 100U), TensorShape(100U), TensorShape(177U, 148U, 100U), PadStrideInfo(1, 1, 0, 0));
+        add_config(TensorShape(200U, 201U, 24U), TensorShape(5U, 5U, 24U, 61), TensorShape(61U), TensorShape(200U, 201U, 61), PadStrideInfo(1, 1, 2, 2));
+
+        // Batch size 2, 3 and 4
+        add_config(TensorShape(224U, 224U, 3U, 2U), TensorShape(5U, 5U, 3U, 64U), TensorShape(64U), TensorShape(222U, 222U, 64U, 2U), PadStrideInfo(1, 1, 1, 1));
+        add_config(TensorShape(123U, 134U, 16U, 3U), TensorShape(5U, 5U, 16U, 7U), TensorShape(7U), TensorShape(121U, 130U, 7U, 3U), PadStrideInfo(1, 1, 1, 0));
+        add_config(TensorShape(181U, 152U, 42U, 4U), TensorShape(5U, 5U, 42U, 100U), TensorShape(100U), TensorShape(177U, 148U, 100U, 4U), PadStrideInfo(1, 1, 0, 0));
+    }
+};
+
 class LargeConvolutionLayerDataset final : public ConvolutionLayerDataset
 {
 public:
diff --git a/tests/validation/CL/ConvolutionLayer.cpp b/tests/validation/CL/ConvolutionLayer.cpp
index 8685e5b..a2b55a8 100644
--- a/tests/validation/CL/ConvolutionLayer.cpp
+++ b/tests/validation/CL/ConvolutionLayer.cpp
@@ -73,44 +73,72 @@
 TEST_SUITE(CL)
 TEST_SUITE(ConvolutionLayer)
 
-DATA_TEST_CASE(ValidateConvolutionMethod, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(
-                                                                                           framework::dataset::make("InputInfo", { TensorInfo(TensorShape(17U, 31U, 2U), 1, DataType::F32, 0),
-                                                                                                                    TensorInfo(TensorShape(17U, 31U, 2U), 1, DataType::F32, 0),
-                                                                                                                    TensorInfo(TensorShape(23U, 27U, 5U, 4U), 1, DataType::F32, 0),
-                                                                                                                    TensorInfo(TensorShape(3U, 3U, 2U, 1U), 1, DataType::F32, 0),
-                                                                                                                    TensorInfo(TensorShape(33U, 27U, 7U, 4U), 1, DataType::F32, 0)
-                                                                                                                                 }),
-                                                                                           framework::dataset::make("WeightsInfo", { TensorInfo(TensorShape(5U, 5U, 2U, 19U), 1, DataType::F32, 0),
-                                                                                                                    TensorInfo(TensorShape(5U, 5U, 2U, 19U), 1, DataType::F32, 0),
-                                                                                                                    TensorInfo(TensorShape(3U, 3U, 5U, 21U), 1, DataType::F32, 0),
-                                                                                                                    TensorInfo(TensorShape(3U, 3U, 5U, 21U), 1, DataType::F32, 0),
-                                                                                                                    TensorInfo(TensorShape(5U, 5U, 7U, 16U), 1, DataType::F16, 0)
-                                                                                                                                   })),
-                                                                                       framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(15U, 15U, 19U), 1, DataType::F32, 0),
-                                                                                                                TensorInfo(TensorShape(15U, 15U, 19U), 1, DataType::F32, 0),
-                                                                                                                TensorInfo(TensorShape(21U, 25U, 21U, 4U), 1, DataType::F32, 0),
-                                                                                                                TensorInfo(TensorShape(11U, 25U, 21U), 1, DataType::F32, 0),
-                                                                                                                TensorInfo(TensorShape(11U, 12U, 16U, 4U), 1, DataType::F32, 0)
-                                                                                                                              })),
-                                                                                   framework::dataset::make("ConvInfo", { PadStrideInfo(1, 2, 1, 1),
-                                                                                                            PadStrideInfo(1, 2, 1, 1),
-                                                                                                            PadStrideInfo(1, 1, 0, 0),
-                                                                                                            PadStrideInfo(2, 1, 0, 0),
-                                                                                                            PadStrideInfo(3, 2, 1, 0)
-                                                                                                                        })),
-                                                                               framework::dataset::make("GpuTarget", { GPUTarget::BIFROST,
-                                                                                                                       GPUTarget::MIDGARD,
-                                                                                                                       GPUTarget::G71,
-                                                                                                                       GPUTarget::MIDGARD,
-                                                                                                                       GPUTarget::BIFROST
-                                                                                                                     })),
-
-                                                                           framework::dataset::make("Expected", { ConvolutionMethod::GEMM, ConvolutionMethod::GEMM, ConvolutionMethod::WINOGRAD, ConvolutionMethod::GEMM, ConvolutionMethod::GEMM })),
-               input_info, weights_info, output_info, conv_info, gpu_target, expected)
+DATA_TEST_CASE(ValidateConvolutionMethod, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
+                                                                                                   framework::dataset::make("InputInfo", { TensorInfo(TensorShape(17U, 31U, 2U), 1, DataType::F32, 0),
+                                                                                                           TensorInfo(TensorShape(17U, 31U, 2U), 1, DataType::F32, 0),
+                                                                                                           TensorInfo(TensorShape(23U, 27U, 5U, 4U), 1, DataType::F32, 0),
+                                                                                                           TensorInfo(TensorShape(3U, 3U, 2U, 1U), 1, DataType::F32, 0),
+                                                                                                           TensorInfo(TensorShape(33U, 27U, 7U, 4U), 1, DataType::F32, 0),
+                                                                                                           TensorInfo(TensorShape(17U, 31U, 2U), 1, DataType::F32, 0),
+                                                                                                           TensorInfo(TensorShape(17U, 31U, 2U), 1, DataType::F32, 0)
+                                                                                                                                         }),
+                                                                                                   framework::dataset::make("WeightsInfo", { TensorInfo(TensorShape(5U, 5U, 2U, 19U), 1, DataType::F32, 0),
+                                                                                                           TensorInfo(TensorShape(5U, 5U, 2U, 19U), 1, DataType::F32, 0),
+                                                                                                           TensorInfo(TensorShape(3U, 3U, 5U, 21U), 1, DataType::F32, 0),
+                                                                                                           TensorInfo(TensorShape(3U, 3U, 5U, 21U), 1, DataType::F32, 0),
+                                                                                                           TensorInfo(TensorShape(5U, 5U, 7U, 16U), 1, DataType::F16, 0),
+                                                                                                           TensorInfo(TensorShape(5U, 5U, 2U, 19U), 1, DataType::F32, 0),
+                                                                                                           TensorInfo(TensorShape(5U, 5U, 2U, 19U), 1, DataType::F32, 0)
+                                                                                                                                           })),
+                                                                                               framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(15U, 15U, 19U), 1, DataType::F32, 0),
+                                                                                                                        TensorInfo(TensorShape(15U, 15U, 19U), 1, DataType::F32, 0),
+                                                                                                                        TensorInfo(TensorShape(21U, 25U, 21U, 4U), 1, DataType::F32, 0),
+                                                                                                                        TensorInfo(TensorShape(11U, 25U, 21U), 1, DataType::F32, 0),
+                                                                                                                        TensorInfo(TensorShape(11U, 12U, 16U, 4U), 1, DataType::F32, 0),
+                                                                                                                        TensorInfo(TensorShape(17U, 31U, 19U), 1, DataType::F32, 0),
+                                                                                                                        TensorInfo(TensorShape(17U, 31U, 19U), 1, DataType::F32, 0)
+                                                                                                                                      })),
+                                                                                           framework::dataset::make("ConvInfo", { PadStrideInfo(1, 2, 1, 1),
+                                                                                                                    PadStrideInfo(1, 2, 1, 1),
+                                                                                                                    PadStrideInfo(1, 1, 0, 0),
+                                                                                                                    PadStrideInfo(2, 1, 0, 0),
+                                                                                                                    PadStrideInfo(3, 2, 1, 0),
+                                                                                                                    PadStrideInfo(1, 1, 2, 2),
+                                                                                                                    PadStrideInfo(1, 1, 2, 2)
+                                                                                                                                })),
+                                                                                       framework::dataset::make("GpuTarget", { GPUTarget::BIFROST,
+                                                                                                                GPUTarget::MIDGARD,
+                                                                                                                GPUTarget::G71,
+                                                                                                                GPUTarget::MIDGARD,
+                                                                                                                GPUTarget::BIFROST,
+                                                                                                                GPUTarget::BIFROST,
+                                                                                                                GPUTarget::BIFROST
+                                                                                                                             })),
+                                                                                   framework::dataset::make("Dilation",
 {
-    ConvolutionMethod is_valid = CLConvolutionLayer::get_convolution_method(&input_info.clone()->set_is_resizable(false),
-                                                                            &weights_info.clone()->set_is_resizable(false),
-                                                                            &output_info.clone()->set_is_resizable(false), conv_info, WeightsInfo(), ActivationLayerInfo(), gpu_target);
+    Size2D(1U, 1U),
+    Size2D(1U, 1U),
+    Size2D(1U, 1U),
+    Size2D(1U, 1U),
+    Size2D(1U, 1U),
+    Size2D(1U, 1U),
+    Size2D(2U, 1U),
+})),
+framework::dataset::make("EnableFastMath", { false, false, false, false, false, true, true })),
+framework::dataset::make("Expected",
+{
+    ConvolutionMethod::GEMM, ConvolutionMethod::GEMM, ConvolutionMethod::WINOGRAD, ConvolutionMethod::GEMM, ConvolutionMethod::GEMM, ConvolutionMethod::WINOGRAD, ConvolutionMethod::GEMM,
+})),
+input_info, weights_info, output_info, conv_info, gpu_target, dilation, enable_fast_math, expected)
+{
+    ConvolutionMethod is_valid = CLConvolutionLayer::get_convolution_method(&input_info.clone()->set_is_resizable(true),
+                                                                            &weights_info.clone()->set_is_resizable(true),
+                                                                            &output_info.clone()->set_is_resizable(true), conv_info,
+                                                                            WeightsInfo(),
+                                                                            ActivationLayerInfo(),
+                                                                            gpu_target,
+                                                                            dilation,
+                                                                            enable_fast_math);
     ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
 }
 TEST_SUITE_END()
diff --git a/tests/validation/CL/DilatedConvolutionLayer.cpp b/tests/validation/CL/DilatedConvolutionLayer.cpp
index e6a765b..9ee002c 100644
--- a/tests/validation/CL/DilatedConvolutionLayer.cpp
+++ b/tests/validation/CL/DilatedConvolutionLayer.cpp
@@ -104,9 +104,9 @@
                                                                            framework::dataset::make("Expected", { ConvolutionMethod::GEMM, ConvolutionMethod::GEMM, ConvolutionMethod::WINOGRAD, ConvolutionMethod::GEMM, ConvolutionMethod::GEMM })),
                input_info, weights_info, output_info, conv_info, gpu_target, dilation, expected)
 {
-    ConvolutionMethod is_valid = CLConvolutionLayer::get_convolution_method(&input_info.clone()->set_is_resizable(false),
-                                                                            &weights_info.clone()->set_is_resizable(false),
-                                                                            &output_info.clone()->set_is_resizable(false), conv_info, WeightsInfo(), ActivationLayerInfo(), gpu_target, dilation);
+    ConvolutionMethod is_valid = CLConvolutionLayer::get_convolution_method(&input_info.clone()->set_is_resizable(true),
+                                                                            &weights_info.clone()->set_is_resizable(true),
+                                                                            &output_info.clone()->set_is_resizable(true), conv_info, WeightsInfo(), ActivationLayerInfo(), gpu_target, dilation);
     ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
 }
 TEST_SUITE_END()
diff --git a/tests/validation/CL/Winograd.cpp b/tests/validation/CL/Winograd.cpp
index 30d8d75..d892c9f 100644
--- a/tests/validation/CL/Winograd.cpp
+++ b/tests/validation/CL/Winograd.cpp
@@ -51,7 +51,8 @@
 {
 namespace
 {
-constexpr AbsoluteTolerance<float> tolerance_f32(0.001f);
+constexpr AbsoluteTolerance<float> tolerance_f32(0.0001f);
+constexpr AbsoluteTolerance<float> tolerance_fast_math_f32(0.1f);
 } // namespace
 
 using namespace arm_compute::misc::shape_calculator;
@@ -379,6 +380,27 @@
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
+TEST_SUITE(EnableFastMath)
+using CLWinogradConvolutionLayerFastMathFixture = WinogradConvolutionLayerFastMathValidationFixture<CLTensor, CLAccessor, CLWinogradConvolutionLayer, float>;
+FIXTURE_DATA_TEST_CASE(RunSmall, CLWinogradConvolutionLayerFastMathFixture, framework::DatasetMode::PRECOMMIT,
+                       combine(combine(framework::dataset::concat(datasets::SmallWinogradConvolutionLayer3x3Dataset(), datasets::SmallWinogradConvolutionLayer5x5Dataset()),
+                                       framework::dataset::make("DataType", { DataType::F32 })),
+                               framework::dataset::make("ActivationLayerInfo", { ActivationLayerInfo() })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, tolerance_fast_math_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLWinogradConvolutionLayerFastMathFixture, framework::DatasetMode::NIGHTLY,
+                       combine(combine(framework::dataset::concat(datasets::LargeWinogradConvolutionLayer3x3Dataset(), datasets::LargeWinogradConvolutionLayer5x5Dataset()),
+                                       framework::dataset::make("DataType", { DataType::F32 })),
+                               framework::dataset::make("ActivationLayerInfo", { ActivationLayerInfo() })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, tolerance_fast_math_f32);
+}
+
+TEST_SUITE_END() // EnableFastMath
 TEST_SUITE_END() // ConvolutionLayer
 
 TEST_SUITE_END() // Winograd
diff --git a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
index 249f9d5..e15931e 100644
--- a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
@@ -35,6 +35,7 @@
 #include "tests/validation/Helpers.h"
 #include "tests/validation/reference/ActivationLayer.h"
 #include "tests/validation/reference/ConvolutionLayer.h"
+#include "tests/validation/reference/GEMM.h"
 #include "tests/validation/reference/Utils.h"
 #include "tests/validation/reference/Winograd.h"
 
@@ -153,6 +154,123 @@
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class WinogradConvolutionLayerFastMathValidationFixture : public framework::Fixture
+{
+public:
+    template <typename...>
+    void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, DataType data_type, ActivationLayerInfo act_info)
+    {
+        ARM_COMPUTE_UNUSED(dilation);
+
+        _target    = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, act_info);
+        _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, act_info);
+    }
+
+protected:
+    template <typename U>
+    void fill(U &&tensor, int i, float min, float max)
+    {
+        switch(tensor.data_type())
+        {
+            case DataType::F32:
+            {
+                std::uniform_real_distribution<> distribution(min, max);
+                library->fill(tensor, distribution, i);
+                break;
+            }
+            default:
+            {
+                ARM_COMPUTE_ERROR("Not supported");
+                library->fill_tensor_uniform(tensor, i);
+                break;
+            }
+        }
+    }
+
+    TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
+                              DataType data_type, ActivationLayerInfo act_info)
+    {
+        // Create tensors
+        TensorType src     = create_tensor<TensorType>(input_shape, data_type, 1);
+        TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1);
+        TensorType bias    = create_tensor<TensorType>(bias_shape, data_type, 1);
+        TensorType dst     = create_tensor<TensorType>(output_shape, data_type, 1);
+
+        // Create and configure function
+        FunctionType conv;
+        ARM_COMPUTE_EXPECT(static_cast<bool>(conv.validate(src.info(), weights.info(), bias.info(), dst.info(), info, act_info, true /* Enable fast math */)), framework::LogLevel::ERRORS);
+        conv.configure(&src, &weights, &bias, &dst, info, act_info, true /* Enable fast math */);
+
+        ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+        // Allocate tensors
+        src.allocator()->allocate();
+        weights.allocator()->allocate();
+        dst.allocator()->allocate();
+        bias.allocator()->allocate();
+
+        ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(!weights.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(!bias.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+        // Fill tensors
+        fill(AccessorType(src), 0, -1.f, 1.f);
+        fill(AccessorType(weights), 1, -1.f, 1.f);
+        fill(AccessorType(bias), 2, -1.f, 1.f);
+
+        // Compute Winograd Convolution function
+        conv.run();
+
+        return dst;
+    }
+
+    SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
+                                      DataType data_type, ActivationLayerInfo act_info)
+    {
+        // Create reference
+        SimpleTensor<T> src{ input_shape, data_type, 1 };
+        SimpleTensor<T> weights{ weights_shape, data_type, 1 };
+        SimpleTensor<T> bias{ bias_shape, data_type, 1 };
+
+        // Fill reference
+        fill(src, 0, -1.f, 1.f);
+        fill(weights, 1, -1.f, 1.f);
+        fill(bias, 2, -1.f, 1.f);
+
+        WinogradInfo winograd_info(Size2D(4U, 4U),
+                                   Size2D(weights_shape[0], weights_shape[1]),
+                                   Size2D(input_shape[0], input_shape[1]),
+                                   info,
+                                   src.data_layout());
+
+        // Compute tensor shapes for input, filter and output transforms
+        TensorShape input_transform_shape  = compute_winograd_input_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info);
+        TensorShape filter_transform_shape = compute_winograd_filter_transform_shape(TensorInfo(weights_shape, 1, data_type), winograd_info);
+        TensorShape batched_gemm_shape     = input_transform_shape;
+        batched_gemm_shape[0]              = filter_transform_shape[0];
+        TensorShape output_transform_shape = compute_winograd_output_transform_shape(TensorInfo(batched_gemm_shape, 1, data_type), winograd_info);
+
+        // Dummy matrix C: only supplies the GEMM output shape; beta is 0 in the call below, so it adds nothing to the result
+        SimpleTensor<T> dummy_c{ batched_gemm_shape, data_type, 1 };
+
+        // Compute Winograd-based convolution
+        SimpleTensor<T> input_transform_out  = reference::winograd_input_transform<T>(src, input_transform_shape, winograd_info);
+        SimpleTensor<T> filter_transform_out = reference::winograd_filter_transform<T>(weights, filter_transform_shape, winograd_info);
+        SimpleTensor<T> batched_gemm         = reference::gemm<T>(input_transform_out, filter_transform_out, dummy_c, 1.0f, 0.0f);
+        SimpleTensor<T> conv_out             = reference::winograd_output_transform<T>(batched_gemm, bias, output_transform_shape, winograd_info);
+
+        return (act_info.enabled()) ? reference::activation_layer<T>(conv_out, act_info) : conv_out;
+    }
+
+    TensorType      _target{};
+    SimpleTensor<T> _reference{};
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
 class WinogradInputTransformValidationFixture : public framework::Fixture
 {
 public:
@@ -373,11 +491,13 @@
     {
         // Create reference
         SimpleTensor<T> src{ input_shape, data_type };
+        SimpleTensor<T> bias{ TensorShape(input_shape[0]), data_type };
 
         // Fill reference
         fill(src, 0, -1.f, 1.f);
+        fill(bias, 1, 0.0f, 0.0f); // Fill with zeros as we validate just the output transform without bias contribution
 
-        return reference::winograd_output_transform<T>(src, output_shape, winograd_info);
+        return reference::winograd_output_transform<T>(src, bias, output_shape, winograd_info);
     }
 
     TensorType      _target{};
diff --git a/tests/validation/reference/GEMM.cpp b/tests/validation/reference/GEMM.cpp
index 77d025e..f9dcfcb 100644
--- a/tests/validation/reference/GEMM.cpp
+++ b/tests/validation/reference/GEMM.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -41,23 +41,44 @@
     SimpleTensor<T> dst{ c.shape(), c.data_type(), 1, c.fixed_point_position() };
 
     // Compute reference
-    const int M = dst.shape().y();
-    const int N = dst.shape().x();
+    const int M = a.shape().y();
+    const int N = b.shape().x();
     const int K = a.shape().x();
+    const int D = a.shape().z(); // Number of matrices in a batch
+    const int W = a.shape()[3];  // Number of batched GEMMs (Winograd case)
 
-    for(int row = 0; row < M; ++row)
+    const int a_stride_z = K * M;
+    const int a_stride_w = K * M * D;
+
+    const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0;     // Do not slide the matrix B along the 3rd dimension in case matrix B has fewer than 3 dimensions
+    const int b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has fewer than 4 dimensions
+
+    const int c_stride_z = N * M;
+    const int c_stride_w = N * M * D;
+
+    for(int w = 0; w < W; ++w)
     {
-        for(int col = 0; col < N; ++col)
+        for(int depth = 0; depth < D; ++depth)
         {
-            T acc(0);
+            const int base_addr_a = depth * a_stride_z + w * a_stride_w;
+            const int base_addr_b = depth * b_stride_z + w * b_stride_w;
+            const int base_addr_c = depth * c_stride_z + w * c_stride_w;
 
-            for(int k = 0; k < K; ++k)
+            for(int row = 0; row < M; ++row)
             {
-                acc += a[row * K + k] * b[k * N + col];
-            }
+                for(int col = 0; col < N; ++col)
+                {
+                    T acc(0);
 
-            // Finalize the result: alpha * A * B + beta * C
-            dst[col + row * N] = alpha * acc + beta * c[col + row * N];
+                    for(int k = 0; k < K; ++k)
+                    {
+                        acc += a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N];
+                    }
+
+                    // Finalize the result: alpha * A * B + beta * C
+                    dst[base_addr_c + col + row * N] = alpha * acc + beta * c[base_addr_c + col + row * N];
+                }
+            }
         }
     }
 
@@ -75,37 +96,58 @@
     // Compute reference
     using promoted_type = fixed_point_arithmetic::traits::promote_t<T>;
 
-    const int M                    = dst.shape().y();
-    const int N                    = dst.shape().x();
-    const int K                    = a.shape().x();
-    const int fixed_point_position = a.fixed_point_position();
+    const int M = dst.shape().y();
+    const int N = dst.shape().x();
+    const int K = a.shape().x();
+    const int D = a.shape().z(); // Number of matrices in a batch
+    const int W = a.shape()[3];  // Number of batched GEMMs (Winograd case)
 
+    const int a_stride_z = K * M;
+    const int a_stride_w = K * M * D;
+
+    const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0;     // Do not slide the matrix B along the 3rd dimension in case matrix B has fewer than 3 dimensions
+    const int b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has fewer than 4 dimensions
+
+    const int c_stride_z = N * M;
+    const int c_stride_w = N * M * D;
+
+    const int            fixed_point_position = a.fixed_point_position();
     const fixed_point<T> alpha_q(alpha, fixed_point_position);
     const fixed_point<T> beta_q(beta, fixed_point_position);
 
-    for(int row = 0; row < M; ++row)
+    for(int w = 0; w < W; ++w)
     {
-        for(int col = 0; col < N; ++col)
+        for(int depth = 0; depth < D; ++depth)
         {
-            fixed_point<promoted_type> acc_q(0, fixed_point_position);
+            const int base_addr_a = depth * a_stride_z + w * a_stride_w;
+            const int base_addr_b = depth * b_stride_z + w * b_stride_w;
+            const int base_addr_c = depth * c_stride_z + w * c_stride_w;
 
-            for(int k = 0; k < K; ++k)
+            for(int row = 0; row < M; ++row)
             {
-                const fixed_point<promoted_type> a0_q(a[row * K + k], fixed_point_position, true);
-                const fixed_point<promoted_type> b0_q(b[k * N + col], fixed_point_position, true);
+                for(int col = 0; col < N; ++col)
+                {
+                    fixed_point<promoted_type> acc_q(0, fixed_point_position);
 
-                acc_q = acc_q + (a0_q * b0_q);
+                    for(int k = 0; k < K; ++k)
+                    {
+                        const fixed_point<promoted_type> a0_q(a[base_addr_a + row * K + k], fixed_point_position, true);
+                        const fixed_point<promoted_type> b0_q(b[base_addr_b + k * N + col], fixed_point_position, true);
+
+                        acc_q = acc_q + (a0_q * b0_q);
+                    }
+
+                    // Finalize the result: alpha * A * B + beta * C
+                    const fixed_point<T> c0_q(c[base_addr_c + col + row * N], fixed_point_position, true);
+
+                    fixed_point<T> res_q(acc_q);
+                    res_q = alpha_q * res_q;
+                    res_q = res_q + (beta_q * c0_q);
+
+                    // Store the result
+                    dst[base_addr_c + col + row * N] = res_q.raw();
+                }
             }
-
-            // Finalize the result: alpha * A * B + beta * C
-            const fixed_point<T> c0_q(c[col + row * N], fixed_point_position, true);
-
-            fixed_point<T> res_q(acc_q);
-            res_q = alpha_q * res_q;
-            res_q = res_q + (beta_q * c0_q);
-
-            // Store the result
-            dst[col + row * N] = res_q.raw();
         }
     }
 
diff --git a/tests/validation/reference/Winograd.cpp b/tests/validation/reference/Winograd.cpp
index 75b1b51..194a78e 100644
--- a/tests/validation/reference/Winograd.cpp
+++ b/tests/validation/reference/Winograd.cpp
@@ -331,7 +331,7 @@
 }
 
 template <typename T>
-SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info)
+SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const SimpleTensor<T> &b, const TensorShape &output_shape, const WinogradInfo &winograd_info)
 {
     ARM_COMPUTE_ERROR_ON_MSG(winograd_info.output_data_layout != DataLayout::NCHW, "Only supported NCHW data format");
 
@@ -444,6 +444,9 @@
                         if((xo + xi < w_out) && (yo + yi < h_out))
                         {
                             out[output_offset + yi * stridey_out + xi] = output_tile[xi + yi * out_tile_w];
+
+                            // Add bias
+                            out[output_offset + yi * stridey_out + xi] += b[zo];
                         }
                     }
                 }
@@ -456,7 +459,7 @@
 
 template SimpleTensor<float> winograd_filter_transform(const SimpleTensor<float> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info);
 template SimpleTensor<float> winograd_input_transform(const SimpleTensor<float> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info);
-template SimpleTensor<float> winograd_output_transform(const SimpleTensor<float> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info);
+template SimpleTensor<float> winograd_output_transform(const SimpleTensor<float> &in, const SimpleTensor<float> &b, const TensorShape &output_shape, const WinogradInfo &winograd_info);
 } // namespace reference
 } // namespace validation
 } // namespace test
diff --git a/tests/validation/reference/Winograd.h b/tests/validation/reference/Winograd.h
index 29181f1..b74c2c3 100644
--- a/tests/validation/reference/Winograd.h
+++ b/tests/validation/reference/Winograd.h
@@ -51,7 +51,7 @@
 SimpleTensor<T> winograd_filter_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info);
 
 template <typename T>
-SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info);
+SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const SimpleTensor<T> &b, const TensorShape &output_shape, const WinogradInfo &winograd_info);
 } // namespace reference
 } // namespace validation
 } // namespace test