COMPMID-415: Use half_float library for F16

Replace the compiler-provided float16_t, and the ARM_COMPUTE_ENABLE_FP16
guards around it, with half_float::half from the header-only half library,
so that the F16 paths of the test suite compile and run on every platform.

3RDPARTY_UPDATE

Change-Id: Iee572e18d5b1df71300d738cc8690f49d7203d5c
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/81353
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
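
For context: half_float::half, from the header-only half library
(http://half.sourceforge.net), is a 16-bit IEEE 754 binary16 type that is
available on every platform, unlike the compiler-provided float16_t, which
only exists when ARM_COMPUTE_ENABLE_FP16 is set. A minimal sketch of the
behaviour the changes below rely on (assuming half/half.hpp is on the
include path):

    #include "half/half.hpp"

    #include <iostream>
    #include <limits>

    int main()
    {
        // The converting constructor from float is explicit, so
        // direct-initialisation is used throughout the tests.
        half_float::half a(1.5f);
        half_float::half b(0.25f);
        half_float::half c = a + b; // arithmetic behaves like a built-in type

        // The library specialises std::numeric_limits, making the hand-rolled
        // specialisation for float16_t (removed below) unnecessary.
        std::cout << static_cast<float>(c) << " "
                  << static_cast<float>(std::numeric_limits<half_float::half>::lowest()) << "\n";
        return 0;
    }
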
diff --git a/tests/AssetsLibrary.h b/tests/AssetsLibrary.h
index 6ecaccb..58738f8 100644
--- a/tests/AssetsLibrary.h
+++ b/tests/AssetsLibrary.h
@@ -24,10 +24,6 @@
 #ifndef __ARM_COMPUTE_TEST_TENSOR_LIBRARY_H__
 #define __ARM_COMPUTE_TEST_TENSOR_LIBRARY_H__
 
-#include "RawTensor.h"
-#include "TensorCache.h"
-#include "Utils.h"
-
 #include "arm_compute/core/Coordinates.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/Helpers.h"
@@ -35,6 +31,10 @@
 #include "arm_compute/core/TensorShape.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Window.h"
+#include "tests/RawTensor.h"
+#include "tests/TensorCache.h"
+#include "tests/Utils.h"
+#include "tests/validation/half.h"
 
 #include <algorithm>
 #include <cstddef>
@@ -43,10 +43,6 @@
 #include <string>
 #include <type_traits>
 
-#if ARM_COMPUTE_ENABLE_FP16
-#include <arm_fp16.h> // needed for float16_t
-#endif                /* ARM_COMPUTE_ENABLE_FP16 */
-
 namespace arm_compute
 {
 namespace test
@@ -476,9 +472,7 @@
             fill(tensor, distribution_s64, seed_offset);
             break;
         }
-#if ARM_COMPUTE_ENABLE_FP16
         case DataType::F16:
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
         case DataType::F32:
         {
             // It doesn't make sense to check [-inf, inf], so hard code it to a big number
@@ -567,14 +561,12 @@
             fill(tensor, distribution_s64, seed_offset);
             break;
         }
-#if ARM_COMPUTE_ENABLE_FP16
         case DataType::F16:
         {
-            std::uniform_real_distribution<float_t> distribution_f16(low, high);
+            std::uniform_real_distribution<float> distribution_f16(low, high);
             fill(tensor, distribution_f16, seed_offset);
             break;
         }
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
         case DataType::F32:
         {
             ARM_COMPUTE_ERROR_ON(!(std::is_same<float, D>::value));
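
Note that std::uniform_real_distribution is only defined for float, double
and long double, which is why the F16 case generates values as float (the
old float_t is typically just an alias of float, so spelling it out is
clearer) and narrows on store. A stand-alone sketch of the same pattern;
the helper name is hypothetical:

    #include "half/half.hpp"

    #include <cstddef>
    #include <random>
    #include <vector>

    std::vector<half_float::half> make_random_f16(std::size_t n, float low, float high, unsigned int seed)
    {
        std::mt19937                          gen(seed);
        std::uniform_real_distribution<float> dist(low, high);

        std::vector<half_float::half> out(n);
        for(auto &v : out)
        {
            v = half_float::half(dist(gen)); // rounds to the nearest representable binary16
        }
        return out;
    }
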
diff --git a/tests/Utils.h b/tests/Utils.h
index ad45bff..0a58d41 100644
--- a/tests/Utils.h
+++ b/tests/Utils.h
@@ -31,6 +31,7 @@
 #include "arm_compute/core/TensorShape.h"
 #include "arm_compute/core/Types.h"
 #include "support/ToolchainSupport.h"
+#include "tests/validation/half.h"
 
 #include <cmath>
 #include <cstddef>
@@ -40,10 +41,6 @@
 #include <string>
 #include <type_traits>
 
-#ifdef ARM_COMPUTE_ENABLE_FP16
-#include <arm_fp16.h> // needed for float16_t
-#endif                /* ARM_COMPUTE_ENABLE_FP16 */
-
 namespace arm_compute
 {
 namespace test
@@ -100,9 +97,7 @@
 template <> struct promote<uint32_t> { using type = uint64_t; };
 template <> struct promote<int32_t> { using type = int64_t; };
 template <> struct promote<float> { using type = float; };
-#ifdef ARM_COMPUTE_ENABLE_FP16
-template <> struct promote<float16_t> { using type = float16_t; };
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
+template <> struct promote<half_float::half> { using type = half_float::half; };
 
 
 template <typename T>
@@ -248,11 +243,9 @@
         case DataType::S64:
             *reinterpret_cast<int64_t *>(ptr) = value;
             break;
-#if ARM_COMPUTE_ENABLE_FP16
         case DataType::F16:
-            *reinterpret_cast<float16_t *>(ptr) = value;
+            *reinterpret_cast<half_float::half *>(ptr) = value;
             break;
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
         case DataType::F32:
             *reinterpret_cast<float *>(ptr) = value;
             break;
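
half_float::half is a trivially copyable wrapper around a single uint16_t,
so storing it through a reinterpret_cast'ed byte pointer behaves exactly
like the float16_t store it replaces. A sketch of the F16 case above as a
stand-alone helper (the function name is hypothetical):

    #include "half/half.hpp"

    #include <cstdint>

    static_assert(sizeof(half_float::half) == 2, "binary16 must occupy exactly two bytes");

    void store_f16(uint8_t *ptr, float value)
    {
        *reinterpret_cast<half_float::half *>(ptr) = half_float::half(value);
    }
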
diff --git a/tests/validation/CL/ArithmeticAddition.cpp b/tests/validation/CL/ArithmeticAddition.cpp
index 6670476..fc1bf59 100644
--- a/tests/validation/CL/ArithmeticAddition.cpp
+++ b/tests/validation/CL/ArithmeticAddition.cpp
@@ -244,7 +244,6 @@
 BOOST_AUTO_TEST_SUITE_END()
 BOOST_AUTO_TEST_SUITE_END()
 
-#ifdef ARM_COMPUTE_ENABLE_FP16
 BOOST_AUTO_TEST_SUITE(F16)
 BOOST_DATA_TEST_CASE(RunSmall, SmallShapes(), shape)
 {
@@ -258,7 +257,6 @@
     validate(CLAccessor(dst), ref_dst);
 }
 BOOST_AUTO_TEST_SUITE_END()
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
 
 BOOST_AUTO_TEST_SUITE(F32)
 BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit") * boost::unit_test::label("nightly"))
diff --git a/tests/validation/CL/ConvolutionLayer.cpp b/tests/validation/CL/ConvolutionLayer.cpp
index 6123571..a3d7140 100644
--- a/tests/validation/CL/ConvolutionLayer.cpp
+++ b/tests/validation/CL/ConvolutionLayer.cpp
@@ -45,6 +45,7 @@
 
 namespace
 {
+const float tolerance_f16 = 1.f;    /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
 const float tolerance_f32 = 1e-03f; /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
 const float tolerance_q   = 1.0f;   /**< Tolerance value for comparing reference's output against implementation's output for fixed point data types */
 
@@ -73,7 +74,7 @@
     BOOST_TEST(!dst.info()->is_resizable());
 
     // Fill tensors
-    if(dt == DataType::F32)
+    if(dt == DataType::F32 || dt == DataType::F16)
     {
         std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
         library->fill(CLAccessor(src), distribution, 0);
@@ -134,7 +135,6 @@
     validate(dst.info()->valid_region(), dst_valid_region);
 }
 
-#ifdef ARM_COMPUTE_ENABLE_FP16
 BOOST_AUTO_TEST_SUITE(Float16)
 BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
 BOOST_DATA_TEST_CASE(SmallConvolutionLayer,
@@ -148,10 +148,9 @@
     RawTensor ref_dst = Reference::compute_reference_convolution_layer(conv_set.src_shape, conv_set.weights_shape, conv_set.bias_shape, conv_set.dst_shape, dt, conv_set.info, 0);
 
     // Validate output
-    validate(CLAccessor(dst), ref_dst, tolerance_f32);
+    validate(CLAccessor(dst), ref_dst, tolerance_f16);
 }
 BOOST_AUTO_TEST_SUITE_END()
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
 
 BOOST_AUTO_TEST_SUITE(Float)
 BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
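
The looser F16 tolerance follows from the format itself: binary16 carries an
11-bit significand, so its machine epsilon is 2^-10 (about 9.8e-4) compared
with 2^-23 for binary32, and the error compounds as the convolution
accumulates products. A quick way to see the gap, assuming half/half.hpp is
available:

    #include "half/half.hpp"

    #include <iostream>
    #include <limits>

    int main()
    {
        std::cout << std::numeric_limits<half_float::half>::epsilon() << "\n"; // ~0.000977
        std::cout << std::numeric_limits<float>::epsilon() << "\n";            // ~1.19209e-07
        return 0;
    }
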
diff --git a/tests/validation/Helpers.h b/tests/validation/Helpers.h
index 191e328..2793c22 100644
--- a/tests/validation/Helpers.h
+++ b/tests/validation/Helpers.h
@@ -24,21 +24,17 @@
 #ifndef __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__
 #define __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__
 
-#include "ILutAccessor.h"
-#include "Types.h"
-#include "ValidationUserConfiguration.h"
-
 #include "arm_compute/core/Types.h"
+#include "tests/ILutAccessor.h"
+#include "tests/Types.h"
+#include "tests/validation/ValidationUserConfiguration.h"
+#include "tests/validation/half.h"
 
 #include <random>
 #include <type_traits>
 #include <utility>
 #include <vector>
 
-#ifdef ARM_COMPUTE_ENABLE_FP16
-#include <arm_fp16.h>
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
-
 namespace arm_compute
 {
 namespace test
@@ -56,9 +52,7 @@
 inline std::pair<T, T> get_activation_layer_test_bounds(ActivationLayerInfo::ActivationFunction activation, int fixed_point_position = 1)
 {
     bool is_float = std::is_same<T, float>::value;
-#ifdef ARM_COMPUTE_ENABLE_FP16
-    is_float = is_float || std::is_same<T, float16_t>::value;
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
+    is_float      = is_float || std::is_same<T, half_float::half>::value;
 
     std::pair<T, T> bounds;
 
diff --git a/tests/validation/Reference.cpp b/tests/validation/Reference.cpp
index 1db3c3f..b94a0e5 100644
--- a/tests/validation/Reference.cpp
+++ b/tests/validation/Reference.cpp
@@ -476,15 +476,13 @@
             library->fill(ref_src, distribution, 0);
             break;
         }
-#ifdef ARM_COMPUTE_ENABLE_FP16
         case DataType::F16:
         {
-            const std::pair<float16_t, float16_t> bounds = get_activation_layer_test_bounds<float16_t>(act_info.activation());
+            const std::pair<half_float::half, half_float::half> bounds = get_activation_layer_test_bounds<half_float::half>(act_info.activation());
             std::uniform_real_distribution<> distribution(bounds.first, bounds.second);
             library->fill(ref_src, distribution, 0);
             break;
         }
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
         case DataType::F32:
         {
             const std::pair<float, float> bounds = get_activation_layer_test_bounds<float>(act_info.activation());
@@ -604,9 +602,9 @@
     TensorShape                             dst_shape = calculate_depth_concatenate_shape(shapes);
 
     // Create tensors
-    for(unsigned int i = 0; i < shapes.size(); ++i)
+    for(const auto &shape : shapes)
     {
-        ref_srcs.push_back(support::cpp14::make_unique<RawTensor>(RawTensor(shapes[i], dt, 1, fixed_point_position)));
+        ref_srcs.push_back(support::cpp14::make_unique<RawTensor>(shape, dt, 1, fixed_point_position));
     }
     RawTensor ref_dst(dst_shape, dt, 1, fixed_point_position);
 
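
support::cpp14::make_unique is the library's C++14 backport; forwarding the
constructor arguments to it, as above, builds the RawTensor in place instead
of constructing a temporary and copying it into the unique_ptr. A minimal
sketch of the difference, with a hypothetical Widget standing in for
RawTensor and std::make_unique standing in for the backport:

    #include <memory>
    #include <vector>

    struct Widget
    {
        Widget(int /*shape*/, int /*data_type*/) {}
    };

    void fill(std::vector<std::unique_ptr<Widget>> &v)
    {
        // before: a temporary Widget is constructed, then copied/moved
        // v.push_back(std::make_unique<Widget>(Widget(3, 1)));

        // after: constructed in place from the forwarded arguments
        v.push_back(std::make_unique<Widget>(3, 1));
    }
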
diff --git a/tests/validation/TensorFactory.h b/tests/validation/TensorFactory.h
index 2f33dd2..a3bb5f9 100644
--- a/tests/validation/TensorFactory.h
+++ b/tests/validation/TensorFactory.h
@@ -24,29 +24,24 @@
 #ifndef __ARM_COMPUTE_TEST_TENSOR_FACTORY_H__
 #define __ARM_COMPUTE_TEST_TENSOR_FACTORY_H__
 
-#include "RawTensor.h"
-#include "Tensor.h"
 #include "arm_compute/core/Error.h"
+#include "tests/RawTensor.h"
+#include "tests/validation/Tensor.h"
+#include "tests/validation/half.h"
 
 #include "boost_wrapper.h"
 
-#if ARM_COMPUTE_ENABLE_FP16
-#include <arm_fp16.h> // needed for float16_t
-#endif                /* ARM_COMPUTE_ENABLE_FP16 */
-
 namespace arm_compute
 {
 namespace test
 {
 namespace validation
 {
-using TensorVariant = boost::variant < Tensor<uint8_t>, Tensor<int8_t>,
+using TensorVariant = boost::variant<Tensor<uint8_t>, Tensor<int8_t>,
       Tensor<uint16_t>, Tensor<int16_t>,
       Tensor<uint32_t>, Tensor<int32_t>,
-#ifdef ARM_COMPUTE_ENABLE_FP16
-      Tensor<float16_t>,
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
-      Tensor<float >>;
+      Tensor<half_float::half>,
+      Tensor<float>>;
 
 /** Helper to create a constant type if the passed reference is constant. */
 template <typename R, typename T>
@@ -95,12 +90,10 @@
                 using value_type_s32 = typename match_const<R, int32_t>::type;
                 v                    = Tensor<int32_t>(shape, dt, fixed_point_position, reinterpret_cast<value_type_s32 *>(data));
                 break;
-#ifdef ARM_COMPUTE_ENABLE_FP16
             case DataType::F16:
-                using value_type_f16 = typename match_const<R, float16_t>::type;
-                v                    = Tensor<float16_t>(shape, dt, fixed_point_position, reinterpret_cast<value_type_f16 *>(data));
+                using value_type_f16 = typename match_const<R, half_float::half>::type;
+                v                    = Tensor<half_float::half>(shape, dt, fixed_point_position, reinterpret_cast<value_type_f16 *>(data));
                 break;
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
             case DataType::F32:
                 using value_type_f32 = typename match_const<R, float>::type;
                 v                    = Tensor<float>(shape, dt, fixed_point_position, reinterpret_cast<value_type_f32 *>(data));
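
With Tensor<half_float::half> now a permanent member of TensorVariant,
visitors no longer need an #ifdef'd F16 branch. A hypothetical visitor, to
illustrate; it assumes the headers included above:

    struct num_elements_visitor : public boost::static_visitor<int>
    {
        template <typename T>
        int operator()(const Tensor<T> &t) const
        {
            // instantiated for half_float::half exactly like any other T
            return t.num_elements();
        }
    };

    // usage: int n = boost::apply_visitor(num_elements_visitor{}, some_variant);
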
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h
index 3190478..359dfe8 100644
--- a/tests/validation/TensorOperations.h
+++ b/tests/validation/TensorOperations.h
@@ -24,18 +24,15 @@
 #ifndef __ARM_COMPUTE_TEST_TENSOR_OPERATIONS_H__
 #define __ARM_COMPUTE_TEST_TENSOR_OPERATIONS_H__
 
-#include "FixedPoint.h"
-#include "Tensor.h"
-#include "Types.h"
-#include "Utils.h"
-#include "support/ToolchainSupport.h"
-
-#include "FixedPoint.h"
-#include "Types.h"
 #include "arm_compute/core/FixedPoint.h"
 #include "arm_compute/core/Types.h"
+#include "support/ToolchainSupport.h"
+#include "tests/Types.h"
+#include "tests/Utils.h"
 #include "tests/validation/FixedPoint.h"
+#include "tests/validation/Tensor.h"
 #include "tests/validation/ValidationUserConfiguration.h"
+#include "tests/validation/half.h"
 
 #include <algorithm>
 #include <array>
@@ -44,26 +41,6 @@
 #include <string>
 #include <vector>
 
-#if ARM_COMPUTE_ENABLE_FP16
-//Beware! most std templates acting on types don't work with the data type float16_t
-namespace std
-{
-template <>
-class numeric_limits<float16_t>
-{
-public:
-    static float16_t lowest()
-    {
-        return -std::numeric_limits<float>::max(); // -inf
-    };
-    static float16_t max()
-    {
-        return std::numeric_limits<float>::max(); // +inf
-    };
-};
-}
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
-
 namespace arm_compute
 {
 namespace test
@@ -77,11 +54,8 @@
 template <class T>
 struct is_floating_point
     : std::integral_constant < bool,
-      std::is_same<float, typename std::remove_cv<T>::type>::value ||
-#ifdef ARM_COMPUTE_ENABLE_FP16
-      std::is_same<float16_t, typename std::remove_cv<T>::type>::value ||
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
-      std::is_same<double, typename std::remove_cv<T>::type>::value || std::is_same<long double, typename std::remove_cv<T>::type>::value >
+      std::is_same<float, typename std::remove_cv<T>::type>::value || std::is_same<half_float::half, typename std::remove_cv<T>::type>::value
+      || std::is_same<double, typename std::remove_cv<T>::type>::value || std::is_same<long double, typename std::remove_cv<T>::type>::value >
 {
 };
 
@@ -184,7 +158,7 @@
 {
     for(int x = 0; x < cols_weights; ++x)
     {
-        T acc = 0.0f;
+        T acc(0);
         for(int y = 0; y < rows_weights; ++y)
         {
             acc += in[y] * weights[x + y * cols_weights];
@@ -456,8 +430,8 @@
 
     for(int i = 0; i < in1.num_elements(); ++i)
     {
-        intermediate_type val = std::abs(static_cast<intermediate_type>(in1[i]) - static_cast<intermediate_type>(in2[i]));
-        out[i]                = saturate_cast<T3>(val);
+        intermediate_type val(std::abs(static_cast<intermediate_type>(in1[i]) - static_cast<intermediate_type>(in2[i])));
+        out[i] = saturate_cast<T3>(val);
     }
 }
 
@@ -708,7 +682,7 @@
     {
         for(int c = 0; c < N; ++c)
         {
-            T acc = 0.0f;
+            T acc(0);
 
             for(int k = 0; k < K; ++k)
             {
@@ -967,10 +941,10 @@
                 out[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x));
                 break;
             case ActivationLayerInfo::ActivationFunction::RELU:
-                out[i] = std::max<T>(0, x);
+                out[i] = std::max(static_cast<T>(0), x);
                 break;
             case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
-                out[i] = std::min<T>(a, std::max<T>(0, x));
+                out[i] = std::min<T>(a, std::max(static_cast<T>(0), x));
                 break;
             case ActivationLayerInfo::ActivationFunction::LEAKY_RELU:
                 out[i] = (x > 0) ? x : a * x;
@@ -1519,16 +1493,16 @@
             {
                 for(int w = 0; w < pooled_w; ++w)
                 {
-                    T   avg_val = 0;
-                    int wstart  = w * pool_stride_x - pad_x;
-                    int hstart  = h * pool_stride_y - pad_y;
-                    int wend    = std::min(wstart + pool_size, w_in + pad_x);
-                    int hend    = std::min(hstart + pool_size, h_in + pad_y);
-                    int pool    = (hend - hstart) * (wend - wstart);
-                    wstart      = std::max(wstart, 0);
-                    hstart      = std::max(hstart, 0);
-                    wend        = std::min(wend, w_in);
-                    hend        = std::min(hend, h_in);
+                    T   avg_val(0);
+                    int wstart = w * pool_stride_x - pad_x;
+                    int hstart = h * pool_stride_y - pad_y;
+                    int wend   = std::min(wstart + pool_size, w_in + pad_x);
+                    int hend   = std::min(hstart + pool_size, h_in + pad_y);
+                    int pool   = (hend - hstart) * (wend - wstart);
+                    wstart     = std::max(wstart, 0);
+                    hstart     = std::max(hstart, 0);
+                    wend       = std::min(wend, w_in);
+                    hend       = std::min(hend, h_in);
                     if(is_floating_point<T>::value)
                     {
                         for(int y = hstart; y < hend; ++y)
@@ -1652,7 +1626,7 @@
         }
 
         // Regularize
-        T sum = 0;
+        T sum(0);
         for(int c = 0; c < cols; ++c)
         {
             const T res       = exp(in[r * cols + c] - max);
@@ -1661,7 +1635,7 @@
         }
 
         // Normalize
-        const T norm_val = 1 / sum;
+        const T norm_val = static_cast<T>(1) / sum;
         for(int c = 0; c < cols; ++c)
         {
             out[r * cols + c] *= norm_val;
diff --git a/tests/validation/Validation.cpp b/tests/validation/Validation.cpp
index 14ee98a..a13eeb0 100644
--- a/tests/validation/Validation.cpp
+++ b/tests/validation/Validation.cpp
@@ -23,16 +23,16 @@
  */
 #include "Validation.h"
 
-#include "IAccessor.h"
-#include "RawTensor.h"
-#include "TypePrinter.h"
-#include "Utils.h"
-
 #include "arm_compute/core/Coordinates.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/FixedPoint.h"
 #include "arm_compute/core/TensorShape.h"
 #include "arm_compute/runtime/Tensor.h"
+#include "tests/IAccessor.h"
+#include "tests/RawTensor.h"
+#include "tests/TypePrinter.h"
+#include "tests/Utils.h"
+#include "tests/validation/half.h"
 
 #include <array>
 #include <cmath>
@@ -40,10 +40,6 @@
 #include <cstdint>
 #include <iomanip>
 
-#ifdef ARM_COMPUTE_ENABLE_FP16
-#include <arm_fp16.h> // needed for float16_t
-#endif                /* ARM_COMPUTE_ENABLE_FP16 */
-
 namespace arm_compute
 {
 namespace test
@@ -88,10 +84,8 @@
             return *reinterpret_cast<const uint64_t *>(ptr);
         case DataType::S64:
             return *reinterpret_cast<const int64_t *>(ptr);
-#ifdef ARM_COMPUTE_ENABLE_FP16
         case DataType::F16:
-            return *reinterpret_cast<const float16_t *>(ptr);
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
+            return *reinterpret_cast<const half_float::half *>(ptr);
         case DataType::F32:
             return *reinterpret_cast<const float *>(ptr);
         case DataType::F64:
diff --git a/tests/validation/half.h b/tests/validation/half.h
new file mode 100644
index 0000000..fb2235a
--- /dev/null
+++ b/tests/validation/half.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_TEST_HALF_H__
+#define __ARM_COMPUTE_TEST_HALF_H__
+
+#ifdef __ANDROID__
+// The Android toolchain doesn't support all of the C++11 math functions.
+#define HALF_ENABLE_CPP11_CMATH 0
+#endif /* __ANDROID__ */
+
+#include "half/half.hpp"
+#endif /* __ARM_COMPUTE_TEST_HALF_H__ */
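
Test code should include this wrapper rather than half/half.hpp directly, so
that the Android HALF_ENABLE_CPP11_CMATH workaround is applied consistently
regardless of include order. Intended usage, sketched with a hypothetical
helper:

    #include "tests/validation/half.h" // not half/half.hpp

    #include <algorithm>

    half_float::half clamp01(half_float::half x)
    {
        return std::min(std::max(x, half_float::half(0.f)), half_float::half(1.f));
    }
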
diff --git a/tests/validation_new/Helpers.h b/tests/validation_new/Helpers.h
new file mode 100644
index 0000000..e25b684
--- /dev/null
+++ b/tests/validation_new/Helpers.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__
+#define __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__
+
+#include "tests/validation/half.h"
+
+#include <type_traits>
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+template <typename T>
+struct is_floating_point : public std::is_floating_point<T>
+{
+};
+
+template <>
+struct is_floating_point<half_float::half> : public std::true_type
+{
+};
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
+#endif /* __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__ */
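
This trait plays the same role as the is_floating_point helper in
tests/validation/TensorOperations.h, but builds on std::is_floating_point so
that only the half specialisation has to be written by hand. A hypothetical
usage sketch:

    #include "tests/validation_new/Helpers.h"

    template <typename T>
    void check_is_fp()
    {
        static_assert(arm_compute::test::validation::is_floating_point<T>::value,
                      "this reference implementation requires a floating point type");
    }

    // check_is_fp<float>();            // OK
    // check_is_fp<half_float::half>(); // OK, thanks to the specialisation above
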
diff --git a/tests/validation_new/Validation.cpp b/tests/validation_new/Validation.cpp
index 8ab8274..9071663 100644
--- a/tests/validation_new/Validation.cpp
+++ b/tests/validation_new/Validation.cpp
@@ -27,16 +27,13 @@
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/TensorShape.h"
 #include "arm_compute/runtime/Tensor.h"
+#include "tests/validation/half.h"
 
 #include <array>
 #include <cmath>
 #include <cstddef>
 #include <cstdint>
 
-#ifdef ARM_COMPUTE_ENABLE_FP16
-#include <arm_fp16.h> // needed for float16_t
-#endif                /* ARM_COMPUTE_ENABLE_FP16 */
-
 namespace arm_compute
 {
 namespace test
@@ -81,10 +78,8 @@
             return *reinterpret_cast<const uint64_t *>(ptr);
         case DataType::S64:
             return *reinterpret_cast<const int64_t *>(ptr);
-#ifdef ARM_COMPUTE_ENABLE_FP16
         case DataType::F16:
-            return *reinterpret_cast<const float16_t *>(ptr);
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
+            return *reinterpret_cast<const half_float::half *>(ptr);
         case DataType::F32:
             return *reinterpret_cast<const float *>(ptr);
         case DataType::F64:
diff --git a/tests/validation_new/Validation.h b/tests/validation_new/Validation.h
index 5e947ca..7db7b00 100644
--- a/tests/validation_new/Validation.h
+++ b/tests/validation_new/Validation.h
@@ -85,8 +85,8 @@
  * reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by
  * other test cases.
  */
-template <typename T, typename U>
-void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value = 0, float tolerance_number = 0.f);
+template <typename T, typename U = T>
+void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value = U(0), float tolerance_number = 0.f);
 
 /** Validate tensors with valid region.
  *
@@ -98,8 +98,8 @@
  * reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by
  * other test cases.
  */
-template <typename T, typename U>
-void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const ValidRegion &valid_region, U tolerance_value = 0, float tolerance_number = 0.f);
+template <typename T, typename U = T>
+void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const ValidRegion &valid_region, U tolerance_value = U(0), float tolerance_number = 0.f);
 
 /** Validate tensors against constant value.
  *