Fix fill() for FP data types in fixtures - Part 2

Resolves: COMPMID-4056

Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Change-Id: I6623eb9c0e66e52af4e0e9fb386031f4a09125b7
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4722
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
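
Every hunk below applies the same pattern, so a minimal standalone sketch of it
may help review; it assumes nothing from the Compute Library test framework.
Fp16, Fp16Distribution and fill_sketch are hypothetical stand-ins for half,
arm_compute::utils::uniform_real_distribution_fp16 and the fixtures' fill()
helpers. The point is only the std::conditional dispatch that keeps
std::uniform_real_distribution instantiated with standard floating point types
(never half), which is what this patch rolls out across the fixtures.

    #include <iostream>
    #include <random>
    #include <type_traits>
    #include <vector>

    // Hypothetical stand-in for the library's 16-bit float type ("half").
    struct Fp16
    {
        float v;
        explicit Fp16(float f) : v(f) {}
    };

    // Hypothetical stand-in for arm_compute::utils::uniform_real_distribution_fp16:
    // it samples in float and only then narrows, so std::uniform_real_distribution
    // is never instantiated with a non-standard floating point type.
    class Fp16Distribution
    {
    public:
        Fp16Distribution(Fp16 lo, Fp16 hi) : _dist(lo.v, hi.v) {}
        template <typename G>
        Fp16 operator()(G &gen) { return Fp16(_dist(gen)); }
    private:
        std::uniform_real_distribution<float> _dist;
    };

    // Same selection logic the fixtures' fill() methods now use.
    template <typename T>
    void fill_sketch(std::vector<T> &data, std::mt19937 &gen)
    {
        static_assert(std::is_floating_point<T>::value || std::is_same<T, Fp16>::value,
                      "Only floating point data types supported.");
        using DistributionType = typename std::conditional<std::is_same<T, Fp16>::value,
                                                           Fp16Distribution,
                                                           std::uniform_real_distribution<T>>::type;
        DistributionType distribution{ T(-1.0f), T(1.0f) };
        for(auto &value : data)
        {
            value = distribution(gen);
        }
    }

    int main()
    {
        std::mt19937       gen(0);
        std::vector<float> f32_data(4, 0.f);
        std::vector<Fp16>  f16_data(4, Fp16(0.f));
        fill_sketch(f32_data, gen);
        fill_sketch(f16_data, gen);
        std::cout << f32_data[0] << " " << f16_data[0].v << std::endl;
        return 0;
    }

Built with any C++11 compiler, the sketch prints one sampled value per type; in
the real fixtures the chosen distribution is handed to library->fill() instead.
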
diff --git a/tests/validation/fixtures/AccumulateFixture.h b/tests/validation/fixtures/AccumulateFixture.h
index 8fa6689..7cea29c 100644
--- a/tests/validation/fixtures/AccumulateFixture.h
+++ b/tests/validation/fixtures/AccumulateFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 Arm Limited.
+ * Copyright (c) 2017-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -139,8 +139,8 @@
     template <typename...>
     void setup(TensorShape shape, DataType data_type, DataType output_data_type)
     {
-        std::mt19937                     gen(library->seed());
-        std::uniform_real_distribution<> float_dist(0, 1);
+        std::mt19937                          gen(library->seed());
+        std::uniform_real_distribution<float> float_dist(0, 1);
 
         _alpha = float_dist(gen);
 
diff --git a/tests/validation/fixtures/ArithmeticDivisionFixture.h b/tests/validation/fixtures/ArithmeticDivisionFixture.h
index 713a6db..60adbfd 100644
--- a/tests/validation/fixtures/ArithmeticDivisionFixture.h
+++ b/tests/validation/fixtures/ArithmeticDivisionFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -55,7 +55,10 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(1.0f, 5.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(1.0f), T(5.0f) };
         library->fill(tensor, distribution, i);
     }
 
diff --git a/tests/validation/fixtures/BatchNormalizationLayerFixture.h b/tests/validation/fixtures/BatchNormalizationLayerFixture.h
index 8a6caac..8685543 100644
--- a/tests/validation/fixtures/BatchNormalizationLayerFixture.h
+++ b/tests/validation/fixtures/BatchNormalizationLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 Arm Limited.
+ * Copyright (c) 2017-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -59,10 +59,14 @@
     template <typename U>
     void fill(U &&src_tensor, U &&mean_tensor, U &&var_tensor, U &&beta_tensor, U &&gamma_tensor)
     {
-        const float                      min_bound = -1.f;
-        const float                      max_bound = 1.f;
-        std::uniform_real_distribution<> distribution(min_bound, max_bound);
-        std::uniform_real_distribution<> distribution_var(0, max_bound);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        const T          min_bound = T(-1.f);
+        const T          max_bound = T(1.f);
+        DistributionType distribution{ min_bound, max_bound };
+        DistributionType distribution_var{ T(0.f), max_bound };
+
         library->fill(src_tensor, distribution, 0);
         library->fill(mean_tensor, distribution, 1);
         library->fill(var_tensor, distribution_var, 0);
@@ -73,7 +77,7 @@
         else
         {
             // Fill with default value 0.f
-            library->fill_tensor_value(beta_tensor, 0.f);
+            library->fill_tensor_value(beta_tensor, T(0.f));
         }
         if(_use_gamma)
         {
@@ -82,7 +86,7 @@
         else
         {
             // Fill with default value 1.f
-            library->fill_tensor_value(gamma_tensor, 1.f);
+            library->fill_tensor_value(gamma_tensor, T(1.f));
         }
     }
 
diff --git a/tests/validation/fixtures/BatchNormalizationLayerFusionFixture.h b/tests/validation/fixtures/BatchNormalizationLayerFusionFixture.h
index 2df7f47..3f7f97a 100644
--- a/tests/validation/fixtures/BatchNormalizationLayerFusionFixture.h
+++ b/tests/validation/fixtures/BatchNormalizationLayerFusionFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -65,16 +65,19 @@
     template <typename U>
     void fill(U &&src, U &&w_tensor, U &&b_tensor, U &&mean_tensor, U &&var_tensor, U &&beta_tensor, U &&gamma_tensor)
     {
-        std::uniform_real_distribution<> distribution(-1.f, 1.f);
-        std::uniform_real_distribution<> distribution_gz(0, 1.f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.f), T(1.f) };
+        DistributionType distribution_gz{ T(0.f), T(1.f) };
 
         library->fill(src, distribution, 0);
         library->fill(w_tensor, distribution, 1);
         library->fill(mean_tensor, distribution, 2);
         library->fill(var_tensor, distribution_gz, 3);
-        _use_conv_b ? library->fill(b_tensor, distribution, 4) : library->fill_tensor_value(b_tensor, 0.f);
-        _use_beta ? library->fill(beta_tensor, distribution, 5) : library->fill_tensor_value(beta_tensor, 0.f);
-        _use_gamma ? library->fill(gamma_tensor, distribution, 6) : library->fill_tensor_value(gamma_tensor, 1.f);
+        _use_conv_b ? library->fill(b_tensor, distribution, 4) : library->fill_tensor_value(b_tensor, T(0.f));
+        _use_beta ? library->fill(beta_tensor, distribution, 5) : library->fill_tensor_value(beta_tensor, T(0.f));
+        _use_gamma ? library->fill(gamma_tensor, distribution, 6) : library->fill_tensor_value(gamma_tensor, T(1.f));
     }
 
     TensorType compute_target(TensorShape src_shape, TensorShape w_shape, TensorShape b_shape, TensorShape dst_shape, PadStrideInfo info, float epsilon)
diff --git a/tests/validation/fixtures/BatchToSpaceLayerFixture.h b/tests/validation/fixtures/BatchToSpaceLayerFixture.h
index ca6d20a..a8d1327 100644
--- a/tests/validation/fixtures/BatchToSpaceLayerFixture.h
+++ b/tests/validation/fixtures/BatchToSpaceLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -50,7 +50,10 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
     }
     TensorType compute_target(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape,
@@ -87,8 +90,8 @@
         // Fill tensors
         fill(AccessorType(input), 0);
         {
-            auto block_shape_data = AccessorType(block_shape);
-            const int idx_width       = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+            auto      block_shape_data = AccessorType(block_shape);
+            const int idx_width        = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
             for(unsigned int i = 0; i < block_shape_shape.x(); ++i)
             {
                 static_cast<int32_t *>(block_shape_data.data())[i] = output_shape[i + idx_width] / input_shape[i + idx_width];
diff --git a/tests/validation/fixtures/ConcatenateLayerFixture.h b/tests/validation/fixtures/ConcatenateLayerFixture.h
index e85f81c..d9615ff 100644
--- a/tests/validation/fixtures/ConcatenateLayerFixture.h
+++ b/tests/validation/fixtures/ConcatenateLayerFixture.h
@@ -70,8 +70,8 @@
         {
             qi = QuantizationInfo(1.f / 255.f, offset_dis(gen));
         }
-        std::bernoulli_distribution      mutate_dis(0.5f);
-        std::uniform_real_distribution<> change_dis(-0.25f, 0.f);
+        std::bernoulli_distribution           mutate_dis(0.5f);
+        std::uniform_real_distribution<float> change_dis(-0.25f, 0.f);
 
         // Generate more shapes based on the input
         for(auto &s : shapes)
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 006c5eb..a4db49f 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -301,9 +301,9 @@
     void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type,
                DataLayout data_layout, QuantizationInfo quantization_info, ActivationLayerInfo act_info, DataType weights_data_type)
     {
-        std::vector<float>               weights_scales{};
-        std::mt19937                     gen(library->seed());
-        std::uniform_real_distribution<> dis(0.01f, 1);
+        std::vector<float>                    weights_scales{};
+        std::mt19937                          gen(library->seed());
+        std::uniform_real_distribution<float> dis(0.01f, 1.f);
         for(size_t i = 0; i < output_shape[2]; ++i)
         {
             weights_scales.push_back(dis(gen));
diff --git a/tests/validation/fixtures/DepthToSpaceLayerFixture.h b/tests/validation/fixtures/DepthToSpaceLayerFixture.h
index 8c2f561..bc9954a 100644
--- a/tests/validation/fixtures/DepthToSpaceLayerFixture.h
+++ b/tests/validation/fixtures/DepthToSpaceLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -50,7 +50,10 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
     }
     TensorType compute_target(TensorShape input_shape, int32_t block_shape, TensorShape output_shape,
diff --git a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
index bb1c105..56e9691 100644
--- a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
@@ -239,7 +239,7 @@
         {
             case DataType::F32:
             {
-                std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+                std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
                 library->fill(tensor, distribution, i);
                 break;
             }
@@ -460,9 +460,9 @@
         const float out_scale = output_quantization_info.uniform().scale;
         const float in_scale  = input_quantization_info.uniform().scale;
 
-        std::vector<float>               weights_scales{};
-        std::mt19937                     gen(library->seed());
-        std::uniform_real_distribution<> dis(0.01f, out_scale / in_scale);
+        std::vector<float>                    weights_scales{};
+        std::mt19937                          gen(library->seed());
+        std::uniform_real_distribution<float> dis(0.01f, out_scale / in_scale);
         for(size_t i = 0; i < in_shape.z() * depth_multiplier; ++i)
         {
             weights_scales.push_back(dis(gen));
diff --git a/tests/validation/fixtures/ElementWiseUnaryFixture.h b/tests/validation/fixtures/ElementWiseUnaryFixture.h
index f414daf..f8e0dfa 100644
--- a/tests/validation/fixtures/ElementWiseUnaryFixture.h
+++ b/tests/validation/fixtures/ElementWiseUnaryFixture.h
@@ -55,17 +55,20 @@
     template <typename U>
     void fill(U &&tensor, int i, DataType data_type)
     {
+        using FloatType             = typename std::conditional < std::is_same<T, half>::value || std::is_floating_point<T>::value, T, float >::type;
+        using FloatDistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<FloatType>>::type;
+
         switch(_op)
         {
             case ElementWiseUnary::EXP:
             {
-                std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+                FloatDistributionType distribution{ FloatType(-1.0f), FloatType(1.0f) };
                 library->fill(tensor, distribution, i);
                 break;
             }
             case ElementWiseUnary::RSQRT:
             {
-                std::uniform_real_distribution<> distribution(1.0f, 2.0f);
+                FloatDistributionType distribution{ FloatType(1.0f), FloatType(2.0f) };
                 library->fill(tensor, distribution, i);
                 break;
             }
@@ -82,7 +85,7 @@
                     }
                     case DataType::F32:
                     {
-                        std::uniform_real_distribution<float> distribution(-2.0f, 2.0f);
+                        FloatDistributionType distribution{ FloatType(-2.0f), FloatType(2.0f) };
                         library->fill(tensor, distribution, i);
                         break;
                     }
@@ -99,19 +102,19 @@
             }
             case ElementWiseUnary::LOG:
             {
-                std::uniform_real_distribution<> distribution(0.0000001f, 100.0f);
+                FloatDistributionType distribution{ FloatType(0.0000001f), FloatType(100.0f) };
                 library->fill(tensor, distribution, i);
                 break;
             }
             case ElementWiseUnary::SIN:
             {
-                std::uniform_real_distribution<> distribution(-100.00f, 100.00f);
+                FloatDistributionType distribution{ FloatType(-100.00f), FloatType(100.00f) };
                 library->fill(tensor, distribution, i);
                 break;
             }
             case ElementWiseUnary::ROUND:
             {
-                std::uniform_real_distribution<> distribution(100.0f, -100.0f);
+                FloatDistributionType distribution{ FloatType(-100.0f), FloatType(100.0f) };
                 library->fill(tensor, distribution, i);
                 break;
             }
diff --git a/tests/validation/fixtures/FlattenLayerFixture.h b/tests/validation/fixtures/FlattenLayerFixture.h
index 9627983..d8480ed 100644
--- a/tests/validation/fixtures/FlattenLayerFixture.h
+++ b/tests/validation/fixtures/FlattenLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 Arm Limited.
+ * Copyright (c) 2017-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -66,7 +66,10 @@
     template <typename U>
     void fill(U &&tensor)
     {
-        std::uniform_real_distribution<> distribution(-1.f, 1.f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, 0);
     }
 
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h
index 86d39c0..3e2fb0b 100644
--- a/tests/validation/fixtures/FullyConnectedLayerFixture.h
+++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h
@@ -89,9 +89,14 @@
             std::uniform_int_distribution<int32_t> distribution(-50, 50);
             library->fill(tensor, distribution, i);
         }
-        else if(is_data_type_float(_data_type))
+        else if(_data_type == DataType::F16)
         {
-            std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+            arm_compute::utils::uniform_real_distribution_fp16 distribution(half(-1.0f), half(1.0f));
+            library->fill(tensor, distribution, i);
+        }
+        else if(_data_type == DataType::F32)
+        {
+            std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
             library->fill(tensor, distribution, i);
         }
         else
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index 192e77e..1dc2af2 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -196,11 +196,14 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
 
         // Fill border with infinity in order to check the presence of NaN values (i.e. inf * 0)
-        std::uniform_real_distribution<> distribution_inf(std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity());
+        DistributionType distribution_inf{ T(std::numeric_limits<float>::infinity()), T(std::numeric_limits<float>::infinity()) };
         library->fill_borders_with_garbage(tensor, distribution_inf, i);
     }
 
@@ -313,7 +316,10 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
     }
 
@@ -436,11 +442,14 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
 
         // Fill border with infinity in order to check the presence of NaN values (i.e. inf * 0)
-        std::uniform_real_distribution<> distribution_inf(std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity());
+        DistributionType distribution_inf{ T(std::numeric_limits<float>::infinity()), T(std::numeric_limits<float>::infinity()) };
         library->fill_borders_with_garbage(tensor, distribution_inf, i);
     }
 
@@ -579,7 +588,10 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
     }
 
@@ -718,11 +730,14 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
 
         // Fill border with infinity in order to check the presence of NaN values (i.e. inf * 0)
-        std::uniform_real_distribution<> distribution_inf(std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity());
+        DistributionType distribution_inf{ T(std::numeric_limits<float>::infinity()), T(std::numeric_limits<float>::infinity()) };
         library->fill_borders_with_garbage(tensor, distribution_inf, i);
     }
 
@@ -887,7 +902,10 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
     }
 
@@ -1047,11 +1065,14 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
 
         // Fill border with infinity in order to check the presence of NaN values (i.e. inf * 0)
-        std::uniform_real_distribution<> distribution_inf(std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity());
+        DistributionType distribution_inf{ T(std::numeric_limits<float>::infinity()), T(std::numeric_limits<float>::infinity()) };
         library->fill_borders_with_garbage(tensor, distribution_inf, i);
     }
 
@@ -1199,7 +1220,10 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
     }
 
@@ -1346,11 +1370,14 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
 
         // Fill border with infinity in order to check the presence of NaN values (i.e. inf * 0)
-        std::uniform_real_distribution<> distribution_inf(std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity());
+        DistributionType distribution_inf{ T(std::numeric_limits<float>::infinity()), T(std::numeric_limits<float>::infinity()) };
         library->fill_borders_with_garbage(tensor, distribution_inf, i);
     }
 
@@ -1474,7 +1501,10 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
     }
 
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index 24c1a24..95f4960 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -218,10 +218,10 @@
 
         if(data_type_b == DataType::QSYMM8_PER_CHANNEL)
         {
-            output_stage.is_quantized_per_channel         = true;
-            const size_t                     num_channels = shape_b[0];
-            std::vector<float>               scales(num_channels);
-            std::uniform_real_distribution<> distribution(0, 1);
+            output_stage.is_quantized_per_channel              = true;
+            const size_t                          num_channels = shape_b[0];
+            std::vector<float>                    scales(num_channels);
+            std::uniform_real_distribution<float> distribution(0.f, 1.f);
             library->fill(scales, distribution, 0);
             output_stage.gemmlowp_multipliers.resize(num_channels);
             output_stage.gemmlowp_shifts.resize(num_channels);
diff --git a/tests/validation/fixtures/InstanceNormalizationLayerFixture.h b/tests/validation/fixtures/InstanceNormalizationLayerFixture.h
index 06ff4d3..3f2853d 100644
--- a/tests/validation/fixtures/InstanceNormalizationLayerFixture.h
+++ b/tests/validation/fixtures/InstanceNormalizationLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -55,7 +55,10 @@
     template <typename U>
     void fill(U &&tensor)
     {
-        std::uniform_real_distribution<> distribution(1.f, 2.f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(1.0f), T(2.0f) };
         library->fill(tensor, distribution, 0);
     }
 
diff --git a/tests/validation/fixtures/L2NormalizeLayerFixture.h b/tests/validation/fixtures/L2NormalizeLayerFixture.h
index c617f10..c3692b3 100644
--- a/tests/validation/fixtures/L2NormalizeLayerFixture.h
+++ b/tests/validation/fixtures/L2NormalizeLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2019 Arm Limited.
+ * Copyright (c) 2017-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -59,7 +59,10 @@
     template <typename U>
     void fill(U &&tensor)
     {
-        std::uniform_real_distribution<> distribution(1.f, 2.f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(1.0f), T(2.0f) };
         library->fill(tensor, distribution, 0);
     }
 
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h
index bf785bb..2b321f5 100644
--- a/tests/validation/fixtures/LSTMLayerFixture.h
+++ b/tests/validation/fixtures/LSTMLayerFixture.h
@@ -61,13 +61,19 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
     }
     template <typename U>
     void fill_custom_val(U &&tensor, float num, int i)
     {
-        std::uniform_real_distribution<> distribution(num, num);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(num), T(num) };
         library->fill(tensor, distribution, i);
     }
     TensorType compute_target(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape,
diff --git a/tests/validation/fixtures/MaxUnpoolingLayerFixture.h b/tests/validation/fixtures/MaxUnpoolingLayerFixture.h
index 086bd6c..49b4c4b 100644
--- a/tests/validation/fixtures/MaxUnpoolingLayerFixture.h
+++ b/tests/validation/fixtures/MaxUnpoolingLayerFixture.h
@@ -65,9 +65,14 @@
     template <typename U>
     void fill(U &&tensor)
     {
-        if(!is_data_type_quantized(tensor.data_type()))
+        if(tensor.data_type() == DataType::F32)
         {
-            std::uniform_real_distribution<> distribution(-1.f, 1.f);
+            std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
+            library->fill(tensor, distribution, 0);
+        }
+        else if(tensor.data_type() == DataType::F16)
+        {
+            arm_compute::utils::uniform_real_distribution_fp16 distribution{ half(-1.0f), half(1.0f) };
             library->fill(tensor, distribution, 0);
         }
         else // data type is quantized_asymmetric
diff --git a/tests/validation/fixtures/MeanStdDevFixture.h b/tests/validation/fixtures/MeanStdDevFixture.h
index ec0599b..c76d7af 100644
--- a/tests/validation/fixtures/MeanStdDevFixture.h
+++ b/tests/validation/fixtures/MeanStdDevFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 Arm Limited.
+ * Copyright (c) 2017-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -50,9 +50,14 @@
     template <typename U>
     void fill(U &&tensor)
     {
-        if(is_data_type_float(tensor.data_type()))
+        if(tensor.data_type() == DataType::F32)
         {
-            std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+            std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
+            library->fill(tensor, distribution, 0);
+        }
+        else if(tensor.data_type() == DataType::F16)
+        {
+            arm_compute::utils::uniform_real_distribution_fp16 distribution{ half(-1.0f), half(1.0f) };
             library->fill(tensor, distribution, 0);
         }
         else
diff --git a/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h b/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h
index 47aa38e..1f1e924 100644
--- a/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h
+++ b/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -56,9 +56,10 @@
     template <typename U>
     void fill(U &&src_tensor)
     {
-        const float                      min_bound = -1.f;
-        const float                      max_bound = 1.f;
-        std::uniform_real_distribution<> distribution(min_bound, max_bound);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(src_tensor, distribution, 0);
     }
 
diff --git a/tests/validation/fixtures/NonMaxSuppressionFixture.h b/tests/validation/fixtures/NonMaxSuppressionFixture.h
index de5d6d5..6d5fc43 100644
--- a/tests/validation/fixtures/NonMaxSuppressionFixture.h
+++ b/tests/validation/fixtures/NonMaxSuppressionFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -59,9 +59,9 @@
 
 protected:
     template <typename U>
-    void fill(U &&tensor, int i, int lo, int hi)
+    void fill(U &&tensor, int i, float lo, float hi)
     {
-        std::uniform_real_distribution<> distribution(lo, hi);
+        std::uniform_real_distribution<float> distribution(lo, hi);
         library->fill_boxes(tensor, distribution, i);
     }
 
diff --git a/tests/validation/fixtures/NormalizationLayerFixture.h b/tests/validation/fixtures/NormalizationLayerFixture.h
index 54dfd59..765e93e 100644
--- a/tests/validation/fixtures/NormalizationLayerFixture.h
+++ b/tests/validation/fixtures/NormalizationLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 Arm Limited.
+ * Copyright (c) 2017-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -59,7 +59,10 @@
     template <typename U>
     void fill(U &&tensor)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, 0);
     }
 
diff --git a/tests/validation/fixtures/NormalizePlanarYUVLayerFixture.h b/tests/validation/fixtures/NormalizePlanarYUVLayerFixture.h
index bd84692..0189261 100644
--- a/tests/validation/fixtures/NormalizePlanarYUVLayerFixture.h
+++ b/tests/validation/fixtures/NormalizePlanarYUVLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2019 Arm Limited.
+ * Copyright (c) 2017-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -56,12 +56,14 @@
     template <typename U>
     void fill(U &&src_tensor, U &&mean_tensor, U &&std_tensor)
     {
+        using FloatDistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<float>>::type;
+
         if(is_data_type_float(_data_type))
         {
-            const float                      min_bound = -1.f;
-            const float                      max_bound = 1.f;
-            std::uniform_real_distribution<> distribution(min_bound, max_bound);
-            std::uniform_real_distribution<> distribution_std(0.1, max_bound);
+            const T               min_bound = T(-1.f);
+            const T               max_bound = T(1.f);
+            FloatDistributionType distribution(min_bound, max_bound);
+            FloatDistributionType distribution_std(T(0.1f), max_bound);
             library->fill(src_tensor, distribution, 0);
             library->fill(mean_tensor, distribution, 1);
             library->fill(std_tensor, distribution_std, 2);
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h
index 9cd1c46..3653de7 100644
--- a/tests/validation/fixtures/PoolingLayerFixture.h
+++ b/tests/validation/fixtures/PoolingLayerFixture.h
@@ -65,9 +65,14 @@
     template <typename U>
     void fill(U &&tensor)
     {
-        if(!is_data_type_quantized(tensor.data_type()))
+        if(tensor.data_type() == DataType::F32)
         {
-            std::uniform_real_distribution<> distribution(-1.f, 1.f);
+            std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
+            library->fill(tensor, distribution, 0);
+        }
+        else if(tensor.data_type() == DataType::F16)
+        {
+            arm_compute::utils::uniform_real_distribution_fp16 distribution{ half(-1.0f), half(1.0f) };
             library->fill(tensor, distribution, 0);
         }
         else // data type is quantized_asymmetric
diff --git a/tests/validation/fixtures/RNNLayerFixture.h b/tests/validation/fixtures/RNNLayerFixture.h
index 1668e94..f1d0d69 100644
--- a/tests/validation/fixtures/RNNLayerFixture.h
+++ b/tests/validation/fixtures/RNNLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -54,7 +54,10 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
     }
 
diff --git a/tests/validation/fixtures/RangeFixture.h b/tests/validation/fixtures/RangeFixture.h
index 604007d..0713db9 100644
--- a/tests/validation/fixtures/RangeFixture.h
+++ b/tests/validation/fixtures/RangeFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -65,9 +65,9 @@
 protected:
     float get_random_end(const DataType output_data_type, const QuantizationInfo qinfo_out, float start, float step)
     {
-        std::uniform_real_distribution<> distribution(1, 100);
-        std::mt19937                     gen(library->seed());
-        float                            end = start;
+        std::uniform_real_distribution<float> distribution(1, 100);
+        std::mt19937                          gen(library->seed());
+        float                                 end = start;
         switch(output_data_type)
         {
             case DataType::U8:
diff --git a/tests/validation/fixtures/ReduceMeanFixture.h b/tests/validation/fixtures/ReduceMeanFixture.h
index 7288761..b7ba942 100644
--- a/tests/validation/fixtures/ReduceMeanFixture.h
+++ b/tests/validation/fixtures/ReduceMeanFixture.h
@@ -58,18 +58,27 @@
     template <typename U>
     void fill(U &&tensor)
     {
-        if(!is_data_type_quantized(tensor.data_type()))
+        if(tensor.data_type() == DataType::F32)
         {
-            std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+            std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
             library->fill(tensor, distribution, 0);
         }
-        else
+        else if(tensor.data_type() == DataType::F16)
+        {
+            arm_compute::utils::uniform_real_distribution_fp16 distribution{ half(-1.0f), half(1.0f) };
+            library->fill(tensor, distribution, 0);
+        }
+        else if(is_data_type_quantized(tensor.data_type()))
         {
             std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
             std::uniform_int_distribution<> distribution(bounds.first, bounds.second);
 
             library->fill(tensor, distribution, 0);
         }
+        else
+        {
+            library->fill_tensor_uniform(tensor, 0);
+        }
     }
 
     TensorType compute_target(TensorShape &src_shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info_input, QuantizationInfo quantization_info_output)
diff --git a/tests/validation/fixtures/ReductionOperationFixture.h b/tests/validation/fixtures/ReductionOperationFixture.h
index 646518d..a8dff1b 100644
--- a/tests/validation/fixtures/ReductionOperationFixture.h
+++ b/tests/validation/fixtures/ReductionOperationFixture.h
@@ -61,12 +61,17 @@
     template <typename U>
     void fill(U &&tensor)
     {
-        if(!is_data_type_quantized(tensor.data_type()))
+        if(tensor.data_type() == DataType::F32)
         {
-            std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+            std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
             library->fill(tensor, distribution, 0);
         }
-        else
+        else if(tensor.data_type() == DataType::F16)
+        {
+            arm_compute::utils::uniform_real_distribution_fp16 distribution{ half(-1.0f), half(1.0f) };
+            library->fill(tensor, distribution, 0);
+        }
+        else if(is_data_type_quantized(tensor.data_type()))
         {
             if(tensor.data_type() == DataType::QASYMM8)
             {
@@ -87,6 +92,10 @@
                 ARM_COMPUTE_ERROR("Not supported");
             }
         }
+        else
+        {
+            library->fill_tensor_uniform(tensor, 0);
+        }
     }
 
     TensorType compute_target(const TensorShape &src_shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info)
diff --git a/tests/validation/fixtures/ScaleFixture.h b/tests/validation/fixtures/ScaleFixture.h
index 1e66306..fc09c8f 100644
--- a/tests/validation/fixtures/ScaleFixture.h
+++ b/tests/validation/fixtures/ScaleFixture.h
@@ -98,9 +98,15 @@
     template <typename U>
     void fill(U &&tensor)
     {
-        if(is_data_type_float(_data_type))
+        if(tensor.data_type() == DataType::F32)
         {
-            library->fill_tensor_uniform(tensor, 0);
+            std::uniform_real_distribution<float> distribution(-5.0f, 5.0f);
+            library->fill(tensor, distribution, 0);
+        }
+        else if(tensor.data_type() == DataType::F16)
+        {
+            arm_compute::utils::uniform_real_distribution_fp16 distribution{ half(-5.0f), half(5.0f) };
+            library->fill(tensor, distribution, 0);
         }
         else if(is_data_type_quantized(tensor.data_type()))
         {
@@ -109,9 +115,7 @@
         }
         else
         {
-            // Restrict range for float to avoid any floating point issues
-            std::uniform_real_distribution<> distribution(-5.0f, 5.0f);
-            library->fill(tensor, distribution, 0);
+            library->fill_tensor_uniform(tensor, 0);
         }
     }
 
diff --git a/tests/validation/fixtures/SoftmaxLayerFixture.h b/tests/validation/fixtures/SoftmaxLayerFixture.h
index 30356d6..ff5003d 100644
--- a/tests/validation/fixtures/SoftmaxLayerFixture.h
+++ b/tests/validation/fixtures/SoftmaxLayerFixture.h
@@ -59,16 +59,25 @@
     template <typename U>
     void fill(U &&tensor)
     {
-        if(!is_data_type_quantized(tensor.data_type()))
+        if(tensor.data_type() == DataType::F32)
         {
-            std::uniform_real_distribution<> distribution(-10.f, 10.f);
+            std::uniform_real_distribution<float> distribution(-10.0f, 10.0f);
             library->fill(tensor, distribution, 0);
         }
-        else // data type is quantized_asymmetric (signed or unsigned)
+        else if(tensor.data_type() == DataType::F16)
+        {
+            arm_compute::utils::uniform_real_distribution_fp16 distribution{ half(-10.0f), half(10.0f) };
+            library->fill(tensor, distribution, 0);
+        }
+        else if(!is_data_type_quantized(tensor.data_type()))
         {
             std::uniform_int_distribution<> distribution(0, 100);
             library->fill(tensor, distribution, 0);
         }
+        else
+        {
+            library->fill_tensor_uniform(tensor, 0);
+        }
     }
 
     TensorType compute_target(const TensorShape &shape, DataType data_type,
diff --git a/tests/validation/fixtures/SpaceToDepthFixture.h b/tests/validation/fixtures/SpaceToDepthFixture.h
index 24ae020..b261fd5 100644
--- a/tests/validation/fixtures/SpaceToDepthFixture.h
+++ b/tests/validation/fixtures/SpaceToDepthFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -50,7 +50,10 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
         library->fill(tensor, distribution, i);
     }
     TensorType compute_target(TensorShape input_shape, TensorShape output_shape, const int block_shape,
diff --git a/tests/validation/fixtures/UNIT/DynamicTensorFixture.h b/tests/validation/fixtures/UNIT/DynamicTensorFixture.h
index 4ac19bf..c3aa63b 100644
--- a/tests/validation/fixtures/UNIT/DynamicTensorFixture.h
+++ b/tests/validation/fixtures/UNIT/DynamicTensorFixture.h
@@ -408,7 +408,7 @@
         {
             case DataType::F32:
             {
-                std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+                std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
                 library->fill(tensor, distribution, i);
                 break;
             }
diff --git a/tests/validation/fixtures/UNIT/WeightsRetentionFixture.h b/tests/validation/fixtures/UNIT/WeightsRetentionFixture.h
index 36d338d..8456141 100644
--- a/tests/validation/fixtures/UNIT/WeightsRetentionFixture.h
+++ b/tests/validation/fixtures/UNIT/WeightsRetentionFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 Arm Limited.
+ * Copyright (c) 2017-2020 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -64,7 +64,10 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        std::uniform_real_distribution<> distribution(0.5f, 1.f);
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(0.5f), T(1.0f) };
         library->fill(tensor, distribution, i);
     }
 
diff --git a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
index e1cc953..1061fd0 100644
--- a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
@@ -72,9 +72,14 @@
         switch(tensor.data_type())
         {
             case DataType::F16:
+            {
+                arm_compute::utils::uniform_real_distribution_fp16 distribution{ half(min), half(max) };
+                library->fill(tensor, distribution, i);
+                break;
+            }
             case DataType::F32:
             {
-                std::uniform_real_distribution<> distribution(min, max);
+                std::uniform_real_distribution<float> distribution(min, max);
                 library->fill(tensor, distribution, i);
                 break;
             }
@@ -183,7 +188,7 @@
             }
             case DataType::F32:
             {
-                std::uniform_real_distribution<> distribution(min, max);
+                std::uniform_real_distribution<float> distribution(min, max);
                 library->fill(tensor, distribution, i);
                 break;
             }
@@ -338,9 +343,14 @@
         switch(tensor.data_type())
         {
             case DataType::F16:
+            {
+                arm_compute::utils::uniform_real_distribution_fp16 distribution{ half(min), half(max) };
+                library->fill(tensor, distribution, i);
+                break;
+            }
             case DataType::F32:
             {
-                std::uniform_real_distribution<> distribution(min, max);
+                std::uniform_real_distribution<float> distribution(min, max);
                 library->fill(tensor, distribution, i);
                 break;
             }
@@ -420,9 +430,14 @@
         switch(tensor.data_type())
         {
             case DataType::F16:
+            {
+                arm_compute::utils::uniform_real_distribution_fp16 distribution{ half(min), half(max) };
+                library->fill(tensor, distribution, i);
+                break;
+            }
             case DataType::F32:
             {
-                std::uniform_real_distribution<> distribution(min, max);
+                std::uniform_real_distribution<float> distribution(min, max);
                 library->fill(tensor, distribution, i);
                 break;
             }