diff --git a/tests/validation/reference/ConvolutionLayer.cpp b/tests/validation/reference/ConvolutionLayer.cpp
index 2d31405..f41a6fc 100644
--- a/tests/validation/reference/ConvolutionLayer.cpp
+++ b/tests/validation/reference/ConvolutionLayer.cpp
@@ -47,8 +47,10 @@
 
 template <typename T, typename TB>
 SimpleTensor<T> convolution_layer_nchw(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst, const PadStrideInfo &info,
-                                       const Size2D &dilation)
+                                       const Size2D &dilation, unsigned int num_groups)
 {
+    ARM_COMPUTE_ERROR_ON((src.shape()[2] / num_groups) != weights.shape()[2]);
+
     // Compute reference
     const int width_in       = src.shape().x();
     const int height_in      = src.shape().y();
@@ -78,23 +80,28 @@
         {
             for(int xi = start_xi; xi < start_xi + end_xi; xi += stride_xi)
             {
-                for(int ofm = 0; ofm < depth_out; ++ofm)
+                for(int group = 0; group < static_cast<int>(num_groups); ++group)
                 {
-                    // Compute input and output offsets
-                    const int offset_in  = r * width_in * height_in * depth_in;
-                    const int xo         = (xi - start_xi) / stride_xi;
-                    const int yo         = (yi - start_yi) / stride_yi;
-                    const int offset_out = xo + yo * width_out + ofm * width_out * height_out + r * width_out * height_out * depth_out;
+                    for(int ofm = 0; ofm < static_cast<int>(depth_out / num_groups); ++ofm)
+                    {
+                        // Compute the per-group input, output, weights and bias offsets
+                        const int offset_in  = r * width_in * height_in * depth_in + (group * (depth_in / num_groups) * width_in * height_in);
+                        const int xo         = (xi - start_xi) / stride_xi;
+                        const int yo         = (yi - start_yi) / stride_yi;
+                        const int offset_out = xo + yo * width_out + ((ofm + group * (depth_out / num_groups)) * width_out * height_out) + (r * width_out * height_out * depth_out);
+                        const int offset_w   = (ofm + group * (depth_out / num_groups)) * width_weights * height_weights * depth_weights;
+                        const int offset_b   = (ofm + group * (depth_out / num_groups));
 
-                    ARM_COMPUTE_ASSERT(xo < width_out);
-                    ARM_COMPUTE_ASSERT(yo < height_out);
+                        ARM_COMPUTE_ASSERT(xo < width_out);
+                        ARM_COMPUTE_ASSERT(yo < height_out);
 
-                    // Compute 3D convolution
-                    convolution_3d::detail::convolution3d(src, weights, bias, dst,
-                                                          offset_in, ofm * width_weights * height_weights * depth_weights, ofm, offset_out,
-                                                          xi, yi,
-                                                          width_in, height_in, depth_in,
-                                                          width_weights, height_weights, dilation.x(), dilation.y());
+                        // Compute 3D convolution over this group's (depth_in / num_groups) input channels
+                        convolution_3d::detail::convolution3d(src, weights, bias, dst,
+                                                              offset_in, offset_w, offset_b, offset_out,
+                                                              xi, yi,
+                                                              width_in, height_in, (depth_in / num_groups),
+                                                              width_weights, height_weights, dilation.x(), dilation.y());
+                    }
                 }
             }
         }
@@ -104,7 +111,7 @@
 }
 template <typename T, typename TB>
 SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const TensorShape &output_shape, const PadStrideInfo &info,
-                                  const Size2D &dilation)
+                                  const Size2D &dilation, unsigned int num_groups)
 {
     // Create reference
     SimpleTensor<T> dst{ output_shape, src.data_type(), 1, src.quantization_info() };
@@ -115,20 +122,20 @@
         SimpleTensor<T> weights_nchw = reference::permute<T>(weights, PermutationVector(1U, 2U, 0U));
         SimpleTensor<T> dst_nchw     = reference::permute<T>(dst, PermutationVector(1U, 2U, 0U));
 
-        return reference::permute<T>(convolution_layer_nchw(src_nchw, weights_nchw, bias, dst_nchw, info, dilation), PermutationVector(2U, 0U, 1U));
+        return reference::permute<T>(convolution_layer_nchw(src_nchw, weights_nchw, bias, dst_nchw, info, dilation, num_groups), PermutationVector(2U, 0U, 1U));
     }
     else
     {
-        return convolution_layer_nchw(src, weights, bias, dst, info, dilation);
+        return convolution_layer_nchw(src, weights, bias, dst, info, dilation, num_groups);
     }
 }
 
 template SimpleTensor<float> convolution_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &bias, const TensorShape &output_shape,
-                                               const PadStrideInfo &info, const Size2D &dilation);
+                                               const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups);
 template SimpleTensor<half> convolution_layer(const SimpleTensor<half> &src, const SimpleTensor<half> &weights, const SimpleTensor<half> &bias, const TensorShape &output_shape,
-                                              const PadStrideInfo &info, const Size2D &dilation);
+                                              const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups);
 template SimpleTensor<uint8_t> convolution_layer(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint8_t> &weights, const SimpleTensor<int32_t> &bias, const TensorShape &output_shape,
-                                                 const PadStrideInfo &info, const Size2D &dilation);
+                                                 const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups);
 } // namespace reference
 } // namespace validation
 } // namespace test
