COMPMID-653 - Arithmetic Subtraction: add support for different data types

Change-Id: I2b3d65c8d8a85ad67b9972713d06f047f5bcd1ae
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/93693
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
diff --git a/tests/validation/CL/ArithmeticSubtraction.cpp b/tests/validation/CL/ArithmeticSubtraction.cpp
index 817a31f..5e7d741 100644
--- a/tests/validation/CL/ArithmeticSubtraction.cpp
+++ b/tests/validation/CL/ArithmeticSubtraction.cpp
@@ -47,8 +47,14 @@
 const auto ArithmeticSubtractionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)),
                                                     framework::dataset::make("DataType",
                                                                              DataType::U8));
-const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
+const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::S16)),
                                                      framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionU8U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)),
+                                                         framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionS16U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::U8)),
+                                                          framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionU8S16S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::S16)),
+                                                          framework::dataset::make("DataType", DataType::S16));
 const auto ArithmeticSubtractionQS8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS8), framework::dataset::make("DataType", DataType::QS8)),
                                                      framework::dataset::make("DataType", DataType::QS8));
 const auto ArithmeticSubtractionQS16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS16), framework::dataset::make("DataType", DataType::QS16)),
@@ -62,8 +68,8 @@
 TEST_SUITE(CL)
 TEST_SUITE(ArithmeticSubtraction)
 
-template <typename T>
-using CLArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T>;
+template <typename T1, typename T2 = T1, typename T3 = T1>
+using CLArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T1, T2, T3>;
 
 TEST_SUITE(U8)
 DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
@@ -89,7 +95,8 @@
     validate(dst.info()->padding(), padding);
 }
 
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionU8Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+                                                                                                                     ArithmeticSubtractionU8Dataset),
                                                                                                                      framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
 {
     // Validate output
@@ -97,14 +104,19 @@
 }
 TEST_SUITE_END()
 
+template <typename T1, typename T2 = T1>
+using CLArithmeticSubtractionToS16Fixture = CLArithmeticSubtractionFixture<T1, T2, int16_t>;
+
 TEST_SUITE(S16)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()),
+                                                                                   framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
+                                                                           framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
                                                                    framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
-               shape, data_type, policy)
+               shape, data_type1, data_type2, policy)
 {
     // Create tensors
-    CLTensor ref_src1 = create_tensor<CLTensor>(shape, data_type);
-    CLTensor ref_src2 = create_tensor<CLTensor>(shape, DataType::S16);
+    CLTensor ref_src1 = create_tensor<CLTensor>(shape, data_type1);
+    CLTensor ref_src2 = create_tensor<CLTensor>(shape, data_type2);
     CLTensor dst      = create_tensor<CLTensor>(shape, DataType::S16);
 
     // Create and Configure function
@@ -121,28 +133,86 @@
     validate(ref_src2.info()->padding(), padding);
     validate(dst.info()->padding(), padding);
 }
-
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
-                                                                                                                     framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+TEST_SUITE(S16_S16_S16)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
+                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
-                                                                                                                   framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
+                                                                                                                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
 TEST_SUITE_END()
 
-template <typename T>
-using CLArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T>;
+TEST_SUITE(U8_U8_S16)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+                       ArithmeticSubtractionU8U8S16Dataset),
+                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+                                                                                                                        ArithmeticSubtractionU8U8S16Dataset),
+                                                                                                                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END()
+
+TEST_SUITE(S16_U8_S16)
+using CLAriSubS16U8ToS16Fixture = CLArithmeticSubtractionToS16Fixture<int16_t, uint8_t>;
+FIXTURE_DATA_TEST_CASE(RunSmall, CLAriSubS16U8ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+                                                                                                               ArithmeticSubtractionS16U8S16Dataset),
+                                                                                                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLAriSubS16U8ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+                                                                                                             ArithmeticSubtractionS16U8S16Dataset),
+                                                                                                     framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END()
+
+TEST_SUITE(U8_S16_S16)
+using CLAriSubU8S16ToS16Fixture = CLArithmeticSubtractionToS16Fixture<uint8_t, int16_t>;
+FIXTURE_DATA_TEST_CASE(RunSmall, CLAriSubU8S16ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+                                                                                                               ArithmeticSubtractionU8S16S16Dataset),
+                                                                                                       framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLAriSubU8S16ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+                                                                                                             ArithmeticSubtractionU8S16S16Dataset),
+                                                                                                     framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END()
+TEST_SUITE_END()
+
+template <typename T1, typename T2 = T1, typename T3 = T1>
+using CLArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T1, T2, T3>;
 
 TEST_SUITE(Quantized)
 TEST_SUITE(QS8)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionQS8Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(),
+                       ArithmeticSubtractionQS8Dataset),
                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
                        framework::dataset::make("FractionalBits", 1, 7)))
 {
@@ -150,7 +220,8 @@
     validate(CLAccessor(_target), _reference);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS8Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(),
+                       ArithmeticSubtractionQS8Dataset),
                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
                        framework::dataset::make("FractionalBits", 1, 7)))
 {
@@ -169,7 +240,8 @@
     validate(CLAccessor(_target), _reference);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS16Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(),
+                       ArithmeticSubtractionQS16Dataset),
                        framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
                        framework::dataset::make("FractionalBits", 1, 15)))
 {