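COMPMID-417: Fix absolute tolerance validation for fixed point types

- Change the fixed point absolute tolerance used by the CL and NEON
  softmax validation tests from AbsoluteTolerance<int8_t> to
  AbsoluteTolerance<int16_t>.
- In compare<AbsoluteTolerance<U>>, compute the absolute difference of
  integral types in int64_t. Previously the result was cast to U, which
  could truncate it.
- Simplify the compare template to take only the tolerance type and
  update all callers accordingly.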

Change-Id: I7a745037136bc6e02d177f65fe4f4cd43873b98e
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/87406
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/tests/validation/CL/SoftmaxLayer.cpp b/tests/validation/CL/SoftmaxLayer.cpp
index 6a22eb1..8c143ec 100644
--- a/tests/validation/CL/SoftmaxLayer.cpp
+++ b/tests/validation/CL/SoftmaxLayer.cpp
@@ -48,7 +48,7 @@
 RelativeTolerance<float>            tolerance_f32(0.001f);
 
 /** Tolerance for fixed point operations */
-constexpr AbsoluteTolerance<int8_t> tolerance_fixed_point(2);
+constexpr AbsoluteTolerance<int16_t> tolerance_fixed_point(2);
 
 /** CNN data types */
 const auto CNNDataTypes = framework::dataset::make("DataType",
@@ -145,15 +145,17 @@
 
 TEST_SUITE(QS16)
 // Testing for fixed point position [1,14) as reciprocal limits the maximum fixed point position to 14
-FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
-                                                                                                                      DataType::QS16)),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+                                                                                                                      framework::dataset::make("DataType",
+                                                                                                                              DataType::QS16)),
                                                                                                                       framework::dataset::make("FractionalBits", 1, 14)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_fixed_point);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
-                                                                                                                    DataType::QS16)),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+                                                                                                                    framework::dataset::make("DataType",
+                                                                                                                            DataType::QS16)),
                                                                                                                     framework::dataset::make("FractionalBits", 1, 14)))
 {
     // Validate output
diff --git a/tests/validation/NEON/SoftmaxLayer.cpp b/tests/validation/NEON/SoftmaxLayer.cpp
index 36f1881..7ac7759 100644
--- a/tests/validation/NEON/SoftmaxLayer.cpp
+++ b/tests/validation/NEON/SoftmaxLayer.cpp
@@ -49,7 +49,7 @@
 constexpr AbsoluteTolerance<float> tolerance_f16(0.0001f);
 #endif /* ARM_COMPUTE_ENABLE_FP16*/
 /** Tolerance for fixed point operations */
-constexpr AbsoluteTolerance<int8_t> tolerance_fixed_point(2);
+constexpr AbsoluteTolerance<int16_t> tolerance_fixed_point(2);
 
 /** CNN data types */
 const auto CNNDataTypes = framework::dataset::make("DataType",
@@ -151,15 +151,17 @@
 
 TEST_SUITE(QS16)
 // Testing for fixed point position [1,14) as reciprocal limits the maximum fixed point position to 14
-FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
-                                                                                                                      DataType::QS16)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+                                                                                                                      framework::dataset::make("DataType",
+                                                                                                                              DataType::QS16)),
                                                                                                                       framework::dataset::make("FractionalBits", 1, 14)))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_fixed_point);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
-                                                                                                                    DataType::QS16)),
+FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+                                                                                                                    framework::dataset::make("DataType",
+                                                                                                                            DataType::QS16)),
                                                                                                                     framework::dataset::make("FractionalBits", 1, 14)))
 {
     // Validate output
diff --git a/tests/validation/Validation.cpp b/tests/validation/Validation.cpp
index 690c4ea..1a08211 100644
--- a/tests/validation/Validation.cpp
+++ b/tests/validation/Validation.cpp
@@ -130,7 +130,7 @@
         const double target         = get_double_data(ptr + channel_offset, tensor.data_type());
         const double reference      = get_double_data(static_cast<const uint8_t *>(border_value) + channel_offset, tensor.data_type());
 
-        if(!compare<AbsoluteTolerance<double>, double>(target, reference))
+        if(!compare<AbsoluteTolerance<double>>(target, reference))
         {
             ARM_COMPUTE_TEST_INFO("id = " << id);
             ARM_COMPUTE_TEST_INFO("channel = " << channel);
@@ -192,7 +192,7 @@
             const double target         = get_double_data(ptr + channel_offset, tensor.data_type());
             const double reference      = get_double_data(reference_value, tensor.data_type());
 
-            if(!compare<AbsoluteTolerance<double>, double>(target, reference))
+            if(!compare<AbsoluteTolerance<double>>(target, reference))
             {
                 ARM_COMPUTE_TEST_INFO("id = " << id);
                 ARM_COMPUTE_TEST_INFO("channel = " << channel);
diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h
index e70c970..6bc42a4 100644
--- a/tests/validation/Validation.h
+++ b/tests/validation/Validation.h
@@ -226,11 +226,11 @@
     T                      _tolerance{};
 };
 
-template <typename T, typename U>
+template <typename T>
 struct compare;
 
 template <typename U>
-struct compare<AbsoluteTolerance<U>, U> : public compare_base<AbsoluteTolerance<U>>
+struct compare<AbsoluteTolerance<U>> : public compare_base<AbsoluteTolerance<U>>
 {
     using compare_base<AbsoluteTolerance<U>>::compare_base;
 
@@ -245,12 +245,16 @@
             return true;
         }
 
-        return static_cast<U>(std::abs(this->_target - this->_reference)) <= static_cast<U>(this->_tolerance);
+        using comparison_type = typename std::conditional<std::is_integral<U>::value, int64_t, U>::type;
+
+        const comparison_type abs_difference(std::abs(static_cast<comparison_type>(this->_target) - static_cast<comparison_type>(this->_reference)));
+
+        return abs_difference <= static_cast<comparison_type>(this->_tolerance);
     }
 };
 
 template <typename U>
-struct compare<RelativeTolerance<U>, U> : public compare_base<RelativeTolerance<U>>
+struct compare<RelativeTolerance<U>> : public compare_base<RelativeTolerance<U>>
 {
     using compare_base<RelativeTolerance<U>>::compare_base;
 
@@ -325,7 +329,7 @@
                 const T &target_value    = reinterpret_cast<const T *>(tensor(id))[c];
                 const T &reference_value = reinterpret_cast<const T *>(reference(id))[c];
 
-                if(!compare<U, typename U::value_type>(target_value, reference_value, tolerance_value))
+                if(!compare<U>(target_value, reference_value, tolerance_value))
                 {
                     ARM_COMPUTE_TEST_INFO("id = " << id);
                     ARM_COMPUTE_TEST_INFO("channel = " << c);
@@ -359,7 +363,7 @@
     ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << framework::make_printable(reference));
     ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << framework::make_printable(target));
     ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << framework::make_printable(static_cast<typename U::value_type>(tolerance)));
-    ARM_COMPUTE_EXPECT((compare<U, typename U::value_type>(target, reference, tolerance)), framework::LogLevel::ERRORS);
+    ARM_COMPUTE_EXPECT((compare<U>(target, reference, tolerance)), framework::LogLevel::ERRORS);
 }
 } // namespace validation
 } // namespace test