COMPMID-653 - Arithmetic Subtraction: add support for different data types

Change-Id: I2b3d65c8d8a85ad67b9972713d06f047f5bcd1ae
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/93693
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
diff --git a/tests/validation/NEON/ArithmeticSubtraction.cpp b/tests/validation/NEON/ArithmeticSubtraction.cpp
index dcaf9d9..fcd415b 100644
--- a/tests/validation/NEON/ArithmeticSubtraction.cpp
+++ b/tests/validation/NEON/ArithmeticSubtraction.cpp
@@ -49,6 +49,12 @@
                                                                              DataType::U8));
 const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
                                                      framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionU8U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)), // Mixed-type (src1, src2, dst) combos, order as in suite names: U8 - U8 -> S16
+                                                         framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionS16U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::U8)), // S16 - U8 -> S16
+                                                          framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionU8S16S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::S16)), // U8 - S16 -> S16
+                                                          framework::dataset::make("DataType", DataType::S16)); // NOTE(review): ArithmeticSubtractionS16Dataset above also contains U8 - S16 -> S16
 const auto ArithmeticSubtractionQS8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS8), framework::dataset::make("DataType", DataType::QS8)),
                                                      framework::dataset::make("DataType", DataType::QS8));
 const auto ArithmeticSubtractionQS16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS16), framework::dataset::make("DataType", DataType::QS16)),
@@ -64,8 +70,8 @@
 TEST_SUITE(NEON)
 TEST_SUITE(ArithmeticSubtraction)
 
-template <typename T>
-using NEArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<Tensor, Accessor, NEArithmeticSubtraction, T>;
+template <typename T1, typename T2 = T1, typename T3 = T1> // T1/T2: src1/src2 element types, T3: dst element type
+using NEArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<Tensor, Accessor, NEArithmeticSubtraction, T1, T2, T3>; // a single argument gives all three the same type
 
 TEST_SUITE(U8)
 DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
@@ -99,14 +105,19 @@
 }
 TEST_SUITE_END()
 
+template <typename T1, typename T2 = T1> // T1: src1 element type, T2: src2 element type (defaults to T1)
+using NEArithmeticSubtractionToS16Fixture = NEArithmeticSubtractionFixture<T1, T2, int16_t>; // dst element type fixed to int16_t (S16)
+
 TEST_SUITE(S16)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()),
+                                                                                   framework::dataset::make("DataType", { DataType::U8, DataType::S16 })), // src1 data type
+                                                                           framework::dataset::make("DataType", { DataType::U8, DataType::S16 })), // src2 data type
                                                                    framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
-               shape, data_type, policy)
+               shape, data_type1, data_type2, policy) // data_type1/data_type2: data types of the two inputs
 {
     // Create tensors
-    Tensor ref_src1 = create_tensor<Tensor>(shape, data_type);
-    Tensor ref_src2 = create_tensor<Tensor>(shape, DataType::S16);
+    Tensor ref_src1 = create_tensor<Tensor>(shape, data_type1); // first input: U8 or S16
+    Tensor ref_src2 = create_tensor<Tensor>(shape, data_type2); // second input: U8 or S16
     Tensor dst      = create_tensor<Tensor>(shape, DataType::S16);
 
     // Create and Configure function
@@ -124,27 +135,86 @@
     validate(dst.info()->padding(), padding);
 }
 
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
-                                                                                                                     framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+TEST_SUITE(S16_S16_S16) // NOTE(review): ArithmeticSubtractionS16Dataset also yields U8 as src1, duplicating U8_S16_S16 (here with T1 = int16_t); consider an S16-only dataset
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
+                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
 {
     // Validate output
     validate(Accessor(_target), _reference);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
-                                                                                                                   framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
+                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) // large-shape counterpart of RunSmall, NIGHTLY only
 {
     // Validate output
     validate(Accessor(_target), _reference);
 }
 TEST_SUITE_END()
 
-template <typename T>
-using NEArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<Tensor, Accessor, NEArithmeticSubtraction, T>;
+// Mixed-type suites, named after the (src1, src2, dst) data types they exercise; the output
+// is always S16 because they all go through NEArithmeticSubtractionToS16Fixture.
+TEST_SUITE(U8_U8_S16)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionU8U8S16Dataset),
+                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionU8U8S16Dataset),
+                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // U8_U8_S16
+
+TEST_SUITE(S16_U8_S16)
+// FIXTURE_DATA_TEST_CASE is a macro: the comma in a two-argument template-id would split it
+// into two macro arguments, so the mixed-input fixture is named through an alias first.
+using NEArithmeticSubtractionS16U8ToS16Fixture = NEArithmeticSubtractionToS16Fixture<int16_t, uint8_t>;
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionS16U8ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16U8S16Dataset),
+                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionS16U8ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16U8S16Dataset),
+                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // S16_U8_S16
+
+TEST_SUITE(U8_S16_S16)
+// Alias for the same macro-comma reason as in S16_U8_S16.
+using NEArithmeticSubtractionU8S16ToS16Fixture = NEArithmeticSubtractionToS16Fixture<uint8_t, int16_t>;
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionU8S16ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionU8S16S16Dataset),
+                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionU8S16ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionU8S16S16Dataset),
+                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // U8_S16_S16
+TEST_SUITE_END() // S16
+
+// Fixed-point (QS8/QS16) fixture; T2 and T3 default to T1 so a single element type suffices.
+template <typename T1, typename T2 = T1, typename T3 = T1>
+using NEArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<Tensor, Accessor, NEArithmeticSubtraction, T1, T2, T3>;
 
 TEST_SUITE(Quantized)
 TEST_SUITE(QS8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionQS8Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(),
+                       ArithmeticSubtractionQS8Dataset), // QS8 for src1, src2 and dst
                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
                        framework::dataset::make("FractionalBits", 1, 7)))
 {
@@ -152,7 +222,8 @@
     validate(Accessor(_target), _reference);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS8Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(),
+                       ArithmeticSubtractionQS8Dataset), // QS8 for src1, src2 and dst
                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
                        framework::dataset::make("FractionalBits", 1, 7)))
 {
@@ -171,7 +242,8 @@
     validate(Accessor(_target), _reference);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS16Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(),
+                       ArithmeticSubtractionQS16Dataset), // QS16 for src1 and src2
                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
                        framework::dataset::make("FractionalBits", 1, 15)))
 {