COMPMID-653 - Arithmetic Subtraction, add support different datatype

Change-Id: I2b3d65c8d8a85ad67b9972713d06f047f5bcd1ae
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/93693
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
diff --git a/tests/validation/CL/ArithmeticSubtraction.cpp b/tests/validation/CL/ArithmeticSubtraction.cpp
index 817a31f..5e7d741 100644
--- a/tests/validation/CL/ArithmeticSubtraction.cpp
+++ b/tests/validation/CL/ArithmeticSubtraction.cpp
@@ -47,8 +47,14 @@
 const auto ArithmeticSubtractionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)),
                                                     framework::dataset::make("DataType",
                                                                              DataType::U8));
-const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
+const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::S16)),
                                                      framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionU8U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)),
+                                                         framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionS16U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::U8)),
+                                                          framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionU8S16S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::S16)),
+                                                          framework::dataset::make("DataType", DataType::S16));
 const auto ArithmeticSubtractionQS8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS8), framework::dataset::make("DataType", DataType::QS8)),
                                                      framework::dataset::make("DataType", DataType::QS8));
 const auto ArithmeticSubtractionQS16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS16), framework::dataset::make("DataType", DataType::QS16)),
@@ -62,8 +68,8 @@
 TEST_SUITE(CL)
 TEST_SUITE(ArithmeticSubtraction)
 
-template <typename T>
-using CLArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T>;
+template <typename T1, typename T2 = T1, typename T3 = T1>
+using CLArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T1, T2, T3>;
 
 TEST_SUITE(U8)
 DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
@@ -89,7 +95,8 @@
     validate(dst.info()->padding(), padding);
 }
 
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionU8Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+                                                                                                                     ArithmeticSubtractionU8Dataset),
                                                                                                                      framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
 {
     // Validate output
@@ -97,14 +104,19 @@
 }
 TEST_SUITE_END()
 
+template <typename T1, typename T2 = T1>
+using CLArithmeticSubtractionToS16Fixture = CLArithmeticSubtractionFixture<T1, T2, int16_t>;
+
 TEST_SUITE(S16)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()),
+                                                                                   framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
+                                                                           framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
                                                                    framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
-               shape, data_type, policy)
+               shape, data_type1, data_type2, policy)
 {
     // Create tensors
-    CLTensor ref_src1 = create_tensor<CLTensor>(shape, data_type);
-    CLTensor ref_src2 = create_tensor<CLTensor>(shape, DataType::S16);
+    CLTensor ref_src1 = create_tensor<CLTensor>(shape, data_type1);
+    CLTensor ref_src2 = create_tensor<CLTensor>(shape, data_type2);
     CLTensor dst      = create_tensor<CLTensor>(shape, DataType::S16);
 
     // Create and Configure function
@@ -121,28 +133,86 @@
     validate(ref_src2.info()->padding(), padding);
     validate(dst.info()->padding(), padding);
 }
-
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
-                                                                                                                     framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+TEST_SUITE(S16_S16_S16)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
+                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
-                                                                                                                   framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
+                                                                                                                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
 TEST_SUITE_END()
 
-template <typename T>
-using CLArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T>;
+TEST_SUITE(U8_U8_S16)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+                       ArithmeticSubtractionU8U8S16Dataset),
+                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+                                                                                                                        ArithmeticSubtractionU8U8S16Dataset),
+                                                                                                                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END()
+
+TEST_SUITE(S16_U8_S16)
+using CLAriSubS16U8ToS16Fixture = CLArithmeticSubtractionToS16Fixture<int16_t, uint8_t>;
+FIXTURE_DATA_TEST_CASE(RunSmall, CLAriSubS16U8ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+                                                                                                               ArithmeticSubtractionS16U8S16Dataset),
+                                                                                                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLAriSubS16U8ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+                                                                                                             ArithmeticSubtractionS16U8S16Dataset),
+                                                                                                     framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END()
+
+TEST_SUITE(U8_S16_S16)
+using CLAriSubU8S16ToS16Fixture = CLArithmeticSubtractionToS16Fixture<uint8_t, int16_t>;
+FIXTURE_DATA_TEST_CASE(RunSmall, CLAriSubU8S16ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+                                                                                                               ArithmeticSubtractionU8S16S16Dataset),
+                                                                                                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLAriSubU8S16ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+                                                                                                             ArithmeticSubtractionU8S16S16Dataset),
+                                                                                                     framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END()
+TEST_SUITE_END()
+
+template <typename T1, typename T2 = T1, typename T3 = T1>
+using CLArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T1, T2, T3>;
 
 TEST_SUITE(Quantized)
 TEST_SUITE(QS8)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionQS8Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(),
+                       ArithmeticSubtractionQS8Dataset),
                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
                        framework::dataset::make("FractionalBits", 1, 7)))
 {
@@ -150,7 +220,8 @@
     validate(CLAccessor(_target), _reference);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS8Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(),
+                       ArithmeticSubtractionQS8Dataset),
                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
                        framework::dataset::make("FractionalBits", 1, 7)))
 {
@@ -169,7 +240,8 @@
     validate(CLAccessor(_target), _reference);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS16Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(),
+                       ArithmeticSubtractionQS16Dataset),
                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
                        framework::dataset::make("FractionalBits", 1, 15)))
 {
diff --git a/tests/validation/CPP/ArithmeticSubtraction.cpp b/tests/validation/CPP/ArithmeticSubtraction.cpp
index 80bdb15..bed2d37 100644
--- a/tests/validation/CPP/ArithmeticSubtraction.cpp
+++ b/tests/validation/CPP/ArithmeticSubtraction.cpp
@@ -34,23 +34,26 @@
 {
 namespace reference
 {
-template <typename T>
-SimpleTensor<T> arithmetic_subtraction(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, DataType dst_data_type, ConvertPolicy convert_policy)
+template <typename T1, typename T2, typename T3>
+SimpleTensor<T3> arithmetic_subtraction(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, DataType dst_data_type, ConvertPolicy convert_policy)
 {
-    SimpleTensor<T> result(src1.shape(), dst_data_type);
+    SimpleTensor<T3> result(src1.shape(), dst_data_type);
 
-    using intermediate_type = typename common_promoted_signed_type<T>::intermediate_type;
+    using intermediate_type = typename common_promoted_signed_type<typename std::conditional<sizeof(T1) >= sizeof(T2), T1, T2>::type >::intermediate_type;
 
     for(int i = 0; i < src1.num_elements(); ++i)
     {
         intermediate_type val = static_cast<intermediate_type>(src1[i]) - static_cast<intermediate_type>(src2[i]);
-        result[i]             = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T>(val) : static_cast<T>(val);
+        result[i]             = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(val) : static_cast<T3>(val);
     }
 
     return result;
 }
 
 template SimpleTensor<uint8_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
+template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
+template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
+template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<int16_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
 template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
 template SimpleTensor<int8_t> arithmetic_subtraction(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
 template SimpleTensor<half> arithmetic_subtraction(const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
diff --git a/tests/validation/CPP/ArithmeticSubtraction.h b/tests/validation/CPP/ArithmeticSubtraction.h
index 18b0d12..9308314 100644
--- a/tests/validation/CPP/ArithmeticSubtraction.h
+++ b/tests/validation/CPP/ArithmeticSubtraction.h
@@ -35,8 +35,8 @@
 {
 namespace reference
 {
-template <typename T>
-SimpleTensor<T> arithmetic_subtraction(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
+template <typename T1, typename T2, typename T3>
+SimpleTensor<T3> arithmetic_subtraction(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
 } // namespace reference
 } // namespace validation
 } // namespace test
diff --git a/tests/validation/NEON/ArithmeticSubtraction.cpp b/tests/validation/NEON/ArithmeticSubtraction.cpp
index dcaf9d9..fcd415b 100644
--- a/tests/validation/NEON/ArithmeticSubtraction.cpp
+++ b/tests/validation/NEON/ArithmeticSubtraction.cpp
@@ -49,6 +49,12 @@
                                                                              DataType::U8));
 const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
                                                      framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionU8U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)),
+                                                         framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionS16U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::U8)),
+                                                          framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionU8S16S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::S16)),
+                                                          framework::dataset::make("DataType", DataType::S16));
 const auto ArithmeticSubtractionQS8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS8), framework::dataset::make("DataType", DataType::QS8)),
                                                      framework::dataset::make("DataType", DataType::QS8));
 const auto ArithmeticSubtractionQS16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS16), framework::dataset::make("DataType", DataType::QS16)),
@@ -64,8 +70,8 @@
 TEST_SUITE(NEON)
 TEST_SUITE(ArithmeticSubtraction)
 
-template <typename T>
-using NEArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<Tensor, Accessor, NEArithmeticSubtraction, T>;
+template <typename T1, typename T2 = T1, typename T3 = T1>
+using NEArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<Tensor, Accessor, NEArithmeticSubtraction, T1, T2, T3>;
 
 TEST_SUITE(U8)
 DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
@@ -99,14 +105,19 @@
 }
 TEST_SUITE_END()
 
+template <typename T1, typename T2 = T1>
+using NEArithmeticSubtractionToS16Fixture = NEArithmeticSubtractionFixture<T1, T2, int16_t>;
+
 TEST_SUITE(S16)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()),
+                                                                                   framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
+                                                                           framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
                                                                    framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
-               shape, data_type, policy)
+               shape, data_type1, data_type2, policy)
 {
     // Create tensors
-    Tensor ref_src1 = create_tensor<Tensor>(shape, data_type);
-    Tensor ref_src2 = create_tensor<Tensor>(shape, DataType::S16);
+    Tensor ref_src1 = create_tensor<Tensor>(shape, data_type1);
+    Tensor ref_src2 = create_tensor<Tensor>(shape, data_type2);
     Tensor dst      = create_tensor<Tensor>(shape, DataType::S16);
 
     // Create and Configure function
@@ -124,27 +135,86 @@
     validate(dst.info()->padding(), padding);
 }
 
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
-                                                                                                                     framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+TEST_SUITE(S16_S16_S16)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
+                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
 {
     // Validate output
     validate(Accessor(_target), _reference);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
-                                                                                                                   framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
+                                                                                                                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
 {
     // Validate output
     validate(Accessor(_target), _reference);
 }
 TEST_SUITE_END()
 
-template <typename T>
-using NEArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<Tensor, Accessor, NEArithmeticSubtraction, T>;
+TEST_SUITE(U8_U8_S16)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+                       ArithmeticSubtractionU8U8S16Dataset),
+                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+                                                                                                                        ArithmeticSubtractionU8U8S16Dataset),
+                                                                                                                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END()
+
+TEST_SUITE(S16_U8_S16)
+using NEAriSubS16U8ToS16Fixture = NEArithmeticSubtractionToS16Fixture<int16_t, uint8_t>;
+FIXTURE_DATA_TEST_CASE(RunSmall, NEAriSubS16U8ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+                                                                                                               ArithmeticSubtractionS16U8S16Dataset),
+                                                                                                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEAriSubS16U8ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+                                                                                                             ArithmeticSubtractionS16U8S16Dataset),
+                                                                                                     framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END()
+
+TEST_SUITE(U8_S16_S16)
+using NEAriSubU8S16ToS16Fixture = NEArithmeticSubtractionToS16Fixture<uint8_t, int16_t>;
+FIXTURE_DATA_TEST_CASE(RunSmall, NEAriSubU8S16ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+                                                                                                               ArithmeticSubtractionU8S16S16Dataset),
+                                                                                                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEAriSubU8S16ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+                                                                                                             ArithmeticSubtractionU8S16S16Dataset),
+                                                                                                     framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END()
+TEST_SUITE_END()
+
+template <typename T1, typename T2 = T1, typename T3 = T1>
+using NEArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<Tensor, Accessor, NEArithmeticSubtraction, T1, T2, T3>;
 
 TEST_SUITE(Quantized)
 TEST_SUITE(QS8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionQS8Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(),
+                       ArithmeticSubtractionQS8Dataset),
                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
                        framework::dataset::make("FractionalBits", 1, 7)))
 {
@@ -152,7 +222,8 @@
     validate(Accessor(_target), _reference);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS8Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(),
+                       ArithmeticSubtractionQS8Dataset),
                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
                        framework::dataset::make("FractionalBits", 1, 7)))
 {
@@ -171,7 +242,8 @@
     validate(Accessor(_target), _reference);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS16Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(),
+                       ArithmeticSubtractionQS16Dataset),
                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
                        framework::dataset::make("FractionalBits", 1, 15)))
 {
diff --git a/tests/validation/fixtures/ArithmeticSubtractionFixture.h b/tests/validation/fixtures/ArithmeticSubtractionFixture.h
index 2c683d6..9e65fae 100644
--- a/tests/validation/fixtures/ArithmeticSubtractionFixture.h
+++ b/tests/validation/fixtures/ArithmeticSubtractionFixture.h
@@ -40,7 +40,7 @@
 {
 namespace validation
 {
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2 = T1, typename T3 = T1>
 class ArithmeticSubtractionValidationFixedPointFixture : public framework::Fixture
 {
 public:
@@ -93,31 +93,31 @@
         return dst;
     }
 
-    SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, int fixed_point_position)
+    SimpleTensor<T3> compute_reference(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, int fixed_point_position)
     {
         // Create reference
-        SimpleTensor<T> ref_src1{ shape, data_type0, 1, fixed_point_position };
-        SimpleTensor<T> ref_src2{ shape, data_type1, 1, fixed_point_position };
+        SimpleTensor<T1> ref_src1{ shape, data_type0, 1, fixed_point_position };
+        SimpleTensor<T2> ref_src2{ shape, data_type1, 1, fixed_point_position };
 
         // Fill reference
         fill(ref_src1, 0);
         fill(ref_src2, 1);
 
-        return reference::arithmetic_subtraction<T>(ref_src1, ref_src2, output_data_type, convert_policy);
+        return reference::arithmetic_subtraction<T1, T2, T3>(ref_src1, ref_src2, output_data_type, convert_policy);
     }
 
-    TensorType      _target{};
-    SimpleTensor<T> _reference{};
-    int             _fractional_bits{};
+    TensorType       _target{};
+    SimpleTensor<T3> _reference{};
+    int              _fractional_bits{};
 };
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class ArithmeticSubtractionValidationFixture : public ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2 = T1, typename T3 = T1>
+class ArithmeticSubtractionValidationFixture : public ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>
 {
 public:
     template <typename...>
     void setup(TensorShape shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy)
     {
-        ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type0, data_type1, output_data_type, convert_policy, 0);
+        ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>::setup(shape, data_type0, data_type1, output_data_type, convert_policy, 0);
     }
 };
 } // namespace validation