Create MI tests for Type Conversion: CAST

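* Add FP16/BF16 <-> FP32 conversions to CAST argument generation and
  ERROR_IF checking, and add CAST tests to the MI conformance suite
* Add exclusion regexes ("exclude_patterns") to conformance test
  selection, replacing the previous exclude_types mechanism
* Fix the eiCastErrorIf assertion for unsupported input types, which
  previously could never fail (assert True)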

Signed-off-by: James Ward <james.ward@arm.com>
Change-Id: I15bef7451efd5662065060242d35bd7fa3381487
diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py
index 05f6db8..817d0b6 100644
--- a/verif/conformance/test_select.py
+++ b/verif/conformance/test_select.py
@@ -5,6 +5,7 @@
 import itertools
 import json
 import logging
+import re
 from pathlib import Path
 from typing import Any
 from typing import Dict
@@ -129,20 +130,28 @@
         test_dir: Path,
         config: Dict[str, Dict[str, List[Any]]],
         negative=False,
-        exclude_types=None,
     ):
         """Initialise the selection parameters for an operator.
 
-        test_dir: the directory where the tests for all operators can be found
+        test_dir: the directory where the tests for all operators can
+            be found
         config: a dictionary with:
-                "params" - a dictionary with mappings of parameter names to the values
-                    to select (a sub-set of expected values for instance)
+                "params" - a dictionary with mappings of parameter
+                    names to the values to select (a sub-set of
+                    expected values for instance)
                 "permutes" - a list of parameter names to be permuted
-                "preselected" - a list of dictionaries containing parameter names and
-                    pre-chosen values
-                "sparsity" - a dictionary of parameter names with a sparsity value
-                "errorifs" - list of ERRORIF case names to be selected (negative test)
-        negative: bool indicating if negative testing is being selected (ERRORIF tests)
+                "preselected" - a list of dictionaries containing
+                    parameter names and pre-chosen values
+                "sparsity" - a dictionary of parameter names with a
+                    sparsity value
+                "exclude_patterns" - a list of regex's whereby each
+                    match will not be considered for selection.
+                    Exclusion happens BEFORE test selection (i.e.
+                    before permutes are applied).
+                "errorifs" - list of ERRORIF case names to be selected
+                    (negative test)
+        negative: bool indicating if negative testing is being selected
+            (ERRORIF tests)
 
         EXAMPLE CONFIG:
             "params": {
@@ -165,6 +174,9 @@
                     "pad": "pad00"
                 }
             ],
+            "exclude_patterns": [
+                ".*_(i8|i16|i32|b)_out(i8|i16|i32|b)"
+            ],
             "errorifs": [
                 "InputZeroPointNotZero"
             ]
@@ -187,23 +199,34 @@
             )
             config["permutes"] = []
             config["preselected"] = {}
+            config["exclude_patterns"] = []
 
         self.params = config["params"] if "params" in config else {}
         self.permutes = config["permutes"] if "permutes" in config else []
         self.sparsity = config["sparsity"] if "sparsity" in config else {}
         self.preselected = config["preselected"] if "preselected" in config else {}
+        self.exclude_patterns = (
+            config["exclude_patterns"] if "exclude_patterns" in config else []
+        )
         self.non_permutes = [x for x in self.wks_param_names if x not in self.permutes]
         logger.info(f"{self.name}: permutes={self.permutes}")
         logger.info(f"{self.name}: non_permutes={self.non_permutes}")
+        logger.info(f"{self.name}: exclude_patterns={self.exclude_patterns}")
 
-        if exclude_types is None:
-            exclude_types = []
-        self.test_paths = [
-            p
-            for p in self.get_test_paths(test_dir, self.negative)
-            # exclusion of types if requested
-            if self.path_params(p)["type"] not in exclude_types
-        ]
+        self.test_paths = []
+        excluded_paths = []
+        for path in self.get_test_paths(test_dir, self.negative):
+            pattern_match = False
+            for pattern in self.exclude_patterns:
+                if re.fullmatch(pattern, path.name):
+                    excluded_paths.append(path)
+                    pattern_match = True
+                    break
+            if not pattern_match:
+                self.test_paths.append(path)
+
+        logger.debug(f"{self.name}: regex excluded paths={excluded_paths}")
+
         if not self.test_paths:
             logger.error(f"no tests found for {self.name} in {test_dir}")
         logger.debug(f"{self.name}: paths={self.test_paths}")
@@ -861,9 +884,7 @@
     for op_name in Operator.registry:
         if not args.operators or op_name in args.operators:
             op_params = config[op_name] if op_name in config else {}
-            op = Operator.registry[op_name](
-                args.test_dir, op_params, negative, exclude_types=["float"]
-            )
+            op = Operator.registry[op_name](args.test_dir, op_params, negative)
             for test_path in op.select_tests():
                 print(test_path.resolve() if args.full_path else test_path.name)
 
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index 4cf2b57..f31fa71 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -216,6 +216,82 @@
             "tosa-mi"
         ]
     },
+    "cast": {
+        "group": "type_conversion",
+        "generator_negative_dim_range": "1,10",
+        "generator_args": [
+            [
+                "--target-dtype",
+                "fp32",
+                "--target-dtype",
+                "fp16",
+                "--target-dtype",
+                "bf16",
+                "--target-dtype",
+                "int8",
+                "--target-dtype",
+                "int16",
+                "--target-dtype",
+                "int32",
+                "--fp-values-range",
+                "-2.0,2.0",
+                "--tensor-dim-range",
+                "16,64",
+                "--target-rank",
+                "1",
+                "--target-rank",
+                "2",
+                "--target-rank",
+                "3"
+            ],
+            [
+                "--target-dtype",
+                "fp32",
+                "--target-dtype",
+                "fp16",
+                "--target-dtype",
+                "bf16",
+                "--target-dtype",
+                "int8",
+                "--target-dtype",
+                "int16",
+                "--target-dtype",
+                "int32",
+                "--fp-values-range",
+                "-2.0,2.0",
+                "--tensor-dim-range",
+                "1,16",
+                "--target-rank",
+                "4",
+                "--target-rank",
+                "5"
+            ],
+            [
+                "--target-dtype",
+                "fp16",
+                "--target-shape",
+                "1,1,1,65533,1",
+                "--target-shape",
+                "2,65538,1,1"
+            ]
+        ],
+        "params": {
+            "shape": [],
+            "type": [],
+            "output_type": []
+        },
+        "permutes": [
+            "shape",
+            "type",
+            "output_type"
+        ],
+        "exclude_patterns": [
+            ".*_(i8|i16|i32|b)_out(i8|i16|i32|b)"
+        ],
+        "profile": [
+            "tosa-mi"
+        ]
+    },
     "ceil": {
         "group": "ew_unary",
         "generator_args": [
diff --git a/verif/conformance/tosa_verif_conformance_generator.py b/verif/conformance/tosa_verif_conformance_generator.py
index 817b242..4971fb0 100644
--- a/verif/conformance/tosa_verif_conformance_generator.py
+++ b/verif/conformance/tosa_verif_conformance_generator.py
@@ -34,13 +34,11 @@
     "tosa-bi": {
         "operator_test_params": "tosa_base_profile_ops_info.json",
         "framework_tests": "tosa_base_profile_framework_ops_info.json",
-        "exclude_types": [],
     },
     "tosa-mi": {
         # Note: This is just the extra tests not in the base profile!
         "operator_test_params": "tosa_main_profile_ops_info.json",
         "framework_tests": "tosa_main_profile_framework_ops_info.json",
-        "exclude_types": [],
     },
 }
 PROFILES_ALL = "all"
@@ -164,7 +162,6 @@
 def _check_to_include_test(profile, test_name, exclude_negative_tests=False):
     """Check test name for exclusions, return False to indicate excluded."""
     excludes = ["ERRORIF"] if exclude_negative_tests else []
-    excludes.extend(PROFILE_OPS_INFO[profile]["exclude_types"])
 
     for exclusion in excludes:
         if f"_{exclusion}_" in test_name:
@@ -338,7 +335,6 @@
             op_build_dir,
             op_params,
             negative,
-            exclude_types=PROFILE_OPS_INFO[profile]["exclude_types"],
         )
     except KeyError:
         logger.error(f"{operator} operator is not supported by test_select")
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index fed91f6..05a7d2b 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1445,19 +1445,40 @@
         if error_name == ErrorIf.WrongOutputType:
             dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
         elif inDtype == DType.INT8:
-            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FP32]
+            dtypeList = [
+                DType.BOOL,
+                DType.INT16,
+                DType.INT32,
+                DType.FP16,
+                DType.BF16,
+                DType.FP32,
+            ]
         elif inDtype == DType.INT16:
-            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FP32]
+            dtypeList = [
+                DType.BOOL,
+                DType.INT8,
+                DType.INT32,
+                DType.FP16,
+                DType.BF16,
+                DType.FP32,
+            ]
         elif inDtype == DType.INT32:
-            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
+            dtypeList = [
+                DType.BOOL,
+                DType.INT8,
+                DType.INT16,
+                DType.FP16,
+                DType.BF16,
+                DType.FP32,
+            ]
         elif inDtype == DType.BOOL:
             dtypeList = [DType.INT8, DType.INT16, DType.INT32]
         elif inDtype == DType.FP16:
-            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
         elif inDtype == DType.BF16:
-            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
         elif inDtype == DType.FP32:
-            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
         elif error_name == ErrorIf.WrongInputType:
             # Pick some potentially correct output type for incorrect input type
             dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 40c5d13..93f975d 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -314,12 +314,14 @@
 
     @staticmethod
     def eiCastErrorIf(testGen, input_dtype):
-        if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]:
-            outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32]
+        if input_dtype in [DType.BOOL, DType.FP32]:
+            outputDType = [DType.BOOL, DType.INT48, DType.FP32]
+        elif input_dtype in [DType.FP16, DType.BF16]:
+            outputDType = [DType.BOOL, DType.INT48]
         elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
             outputDType = [DType.INT48]
         else:
-            assert True, f"input_dtype ({input_dtype}) not supported"
+            assert False, f"input_dtype ({input_dtype}) not supported"
         return outputDType
 
 
@@ -538,15 +540,24 @@
                     )
                     or (
                         input_dtype == DType.FP16
-                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+                        and output_dtype
+                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                     )
                     or (
                         input_dtype == DType.BF16
-                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+                        and output_dtype
+                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                     )
                     or (
                         input_dtype == DType.FP32
-                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+                        and output_dtype
+                        not in [
+                            DType.INT8,
+                            DType.INT16,
+                            DType.INT32,
+                            DType.FP16,
+                            DType.BF16,
+                        ]
                     )
                 ):
                     error_result = True