Add extensions & profiles support to conformance generator

Keep support for the existing conformance profile configs to ease the transition
New combined config tosa_ext_profile_ops_info.json that supports
extension selection (passed with the new --config option)

Signed-off-by: Won Jeon <won.jeon@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ic04696a01d619d191b9c8abc4ef7f4e8b86c52ca
diff --git a/verif/conformance/tosa_verif_conformance_generator.py b/verif/conformance/tosa_verif_conformance_generator.py
index 433a33f..59e88bb 100644
--- a/verif/conformance/tosa_verif_conformance_generator.py
+++ b/verif/conformance/tosa_verif_conformance_generator.py
@@ -26,6 +26,7 @@
 
 import conformance.model_files as cmf
 from conformance.test_select import Operator
+from conformance.tosa_profiles import TosaProfiles
 from convert2conformance.convert2conformance import main as c2c_main
 from convert2conformance.convert2conformance import OUTPUT_TYPE_DEFAULT
 from convert2conformance.convert2conformance import OUTPUT_TYPES
@@ -47,7 +48,7 @@
         "framework_tests": "tosa_main_profile_framework_ops_info.json",
     },
 }
-PROFILES_ALL = "all"
+PROFILES_EXTENSIONS_ALL = "all"
 
 DEFAULT_SEED = 42
 
@@ -96,10 +97,22 @@
     return (rc.stdout, rc.stderr)
 
 
+def _supports_for_enabled(profile_ext):
+    """Return True if the op config's "support_for" options are honoured for profile_ext."""
+    # Only MI and its related extensions for now - TODO: update with TosaBI etc in future
+    return profile_ext in (
+        TosaProfiles.TosaMI,
+        TosaProfiles.TosaExtFP8E4M3,
+        TosaProfiles.TosaExtFP8E5M2,
+        TosaProfiles.TosaExtBF16,
+        TosaProfiles.TosaExtFFT,
+    )
+
+
 def build_op_tests(
     args,
     test_type,
-    profile,
+    profile_ext,
     operator,
     group,
     gen_args_list,
@@ -115,7 +128,7 @@
     Returns operator output directory
     """
     build_tests_cmd = "tosa_verif_build_tests"
-    op_build_dir = args.build_dir / profile / group
+    op_build_dir = args.build_dir / profile_ext / group
 
     if gen_filter is None:
         gen_filter = f"^{operator}$"
@@ -131,19 +144,19 @@
         "--seed",
         str(args.random_seed),
     ]
-
     if args.verbosity:
         build_cmd_base.append("-" + ("v" * args.verbosity))
 
     if args.tests_list_file is not None:
         build_cmd_base.append("--list-tests")
 
-    if "lazy_data_gen" in supports and args.lazy_data_generation:
-        build_cmd_base.append("--lazy-data-generation")
-    if "stable_random_gen" in supports and not args.global_random_generation:
-        build_cmd_base.append("--stable-random-generation")
-    if "random_const_inputs" in supports:
-        build_cmd_base.append("--random-const-inputs")
+    if _supports_for_enabled(profile_ext):
+        if "lazy_data_gen" in supports and args.lazy_data_generation:
+            build_cmd_base.append("--lazy-data-generation")
+        if "stable_random_gen" in supports and not args.global_random_generation:
+            build_cmd_base.append("--stable-random-generation")
+        if "random_const_inputs" in supports:
+            build_cmd_base.append("--random-const-inputs")
 
     if "generator_select" in supports:
         if selector_info is None:
@@ -252,11 +265,14 @@
     return tests
 
 
-def generate_results(args, profile, operator, op_build_dir, supports=[], tests=None):
+def generate_results(
+    args, profile_ext, operator, op_build_dir, supports=[], tests=None
+):
     """Run tests on reference model and save result to the test directory."""
-    if "lazy_data_gen" in supports and args.lazy_data_generation:
-        logger.info("Skipping running tests due to lazy data gen")
-        return
+    if _supports_for_enabled(profile_ext):
+        if "lazy_data_gen" in supports and args.lazy_data_generation:
+            logger.info("Skipping running tests due to lazy data gen")
+            return
 
     num_cores = args.num_cores
 
@@ -320,11 +336,11 @@
 def convert_tests(
     args,
     test_type,
-    profile,
+    profile_ext,
     operator,
     op_build_dir,
     output_dir,
-    op_profiles_list,
+    op_profiles_extensions_list,
     supports=[],
     tests=None,
     group=None,
@@ -341,22 +357,23 @@
     c2c_args_base.extend(["--output-type", args.output_type])
     # This op maybe in more than one profile - e.g. tosa_bi and tosa_mi
     # even if we are only producing tests for tosa_mi
-    for op_profile in op_profiles_list:
+    for op_profile in op_profiles_extensions_list:
         c2c_args_base.extend(["--profile", op_profile])
     if tags is not None:
         for tag in tags:
             c2c_args_base.extend(["--tag", tag])
     if args.framework_schema:
         c2c_args_base.extend(["--framework-schema", str(args.framework_schema)])
-    if "lazy_data_gen" in supports and args.lazy_data_generation:
-        c2c_args_base.append("--lazy-data-generation")
+    if _supports_for_enabled(profile_ext):
+        if "lazy_data_gen" in supports and args.lazy_data_generation:
+            c2c_args_base.append("--lazy-data-generation")
     c2c_args_base.append("--output-directory")
 
     c2c_args_list = []
 
     if not tests:
         tests = _get_all_tests_list(test_type, op_build_dir, operator)
-        logger.info(f"Converting all {profile} profile tests of type {test_type}")
+        logger.info(f"Converting all {profile_ext} profile tests of type {test_type}")
 
     # Controls if we copy the tests in their operator sub-directory or not
     output_dir_relative_pos = -1 if trim_op_subdir else -2
@@ -408,7 +425,7 @@
 
 def get_op_tests_selection(
     args,
-    profile,
+    profile_ext,
     operator,
     op_build_dir,
     selection_config,
@@ -418,7 +435,11 @@
     """Use test picker to get subsection of tests generated."""
     # Need a full copy of the config as the selector updates it
     config = copy.deepcopy(selection_config)
-    logger.info("Choosing {} tests".format(("negative" if negative else "positive")))
+    logger.info(
+        "Choosing {} tests for {}".format(
+            ("negative" if negative else "positive"), profile_ext
+        )
+    )
     try:
         op = Operator.registry[operator](
             op_build_dir, config, negative=negative, ignore_missing=ignore_missing
@@ -517,15 +538,16 @@
 def parse_args(argv=None):
     """Parse the arguments."""
     parser = argparse.ArgumentParser()
-    profiles = list(PROFILE_OPS_INFO.keys())
-    profiles.append(PROFILES_ALL)
+    profiles = list(TosaProfiles.profiles())
+    profiles.append(PROFILES_EXTENSIONS_ALL)
     parser.add_argument(
         "--profile",
         dest="profile",
         choices=profiles,
-        default=profiles[0],
+        default=[TosaProfiles.TosaBI],
         type=str,
-        help=f"TOSA profile (default is {profiles[0]})",
+        nargs="*",
+        help=f"TOSA profile(s) to create tests for (default is {TosaProfiles.TosaBI})",
     )
     parser.add_argument(
         "--operators",
@@ -535,6 +557,15 @@
         help="The operator(s) to create tests for, if not supplied all tests will be created",
     )
     parser.add_argument(
+        "--extension",
+        dest="extension",
+        choices=TosaProfiles.extensions() + [PROFILES_EXTENSIONS_ALL],
+        default=[],
+        type=str,
+        nargs="*",
+        help="TOSA extension(s) to create tests for (default is none)",
+    )
+    parser.add_argument(
         "--unit-tests",
         dest="unit_tests",
         choices=["operator", "framework", "both"],
@@ -658,6 +689,13 @@
         help=f"Test parameters (ops info) JSON file directory (default is {script_dir})",
     )
     parser.add_argument(
+        "--test-params-json-config",
+        "--config",
+        dest="param_config",
+        type=Path,
+        help="Test parameters (ops info) JSON file (overrides --test-param-json-directory)",
+    )
+    parser.add_argument(
         "--convert-all-tests",
         action="store_true",
         help="Converts all tests instead of those picked by test_select",
@@ -739,6 +777,9 @@
 
     args.param_json_dir = args.param_json_dir.absolute()
 
+    if args.param_config is not None:
+        args.param_config = args.param_config.absolute()
+
     if args.unit_tests in ["framework", "both"]:
         logger.warning(
             "DEPRECATION - Framework tests are not part of TOSA conformance testing"
@@ -827,86 +868,40 @@
 
     # TODO: For tosa-mi should really generate tosa-bi profile as well
     # - for now leave it as subset instead of as superset (for testing)
-    if args.profile == PROFILES_ALL:
-        profiles = list(PROFILE_OPS_INFO.keys())
+    if PROFILES_EXTENSIONS_ALL in args.profile:
+        profiles = TosaProfiles.profiles()
     else:
-        profiles = [args.profile]
+        profiles = args.profile
+
+    if PROFILES_EXTENSIONS_ALL in args.extension:
+        extensions = TosaProfiles.extensions()
+    else:
+        extensions = args.extension
+    profileExtList = profiles + extensions
+    profileExtDone = []
 
     try:
-        for profile in profiles:
-            print(f"Creating conformance tests for TOSA {profile} profile")
+        for profile_ext in profileExtList:
             # Framework unit tests
             if args.unit_tests in ["framework", "both"]:
-                logger.debug("Creating FRAMEWORK unit tests")
-                test_picks_file = (
-                    args.param_json_dir / PROFILE_OPS_INFO[profile]["framework_tests"]
-                )
-                try:
-                    with open(test_picks_file, "r") as fd:
-                        test_picks = json.load(fd)
-                except Exception as e:
-                    logger.error(
-                        f"Couldn't load framework tests info - {test_picks_file}: {e}"
-                    )
-                    return 1
-
-                operators = args.operators
-                if not operators:
-                    # Create tests for all the operators
-                    operators = list(test_picks.keys())
-
-                root_output_dir = (
-                    args.output_dir / "frameworks" / "tflite" / "operators"
-                )
-                for op in operators:
-                    logger.info(f"FRAMEWORK OP: {op}")
-                    if op not in test_picks:
-                        logger.warning(
-                            f"Framework op {op} not found in {test_picks_file} - skipping"
-                        )
-                        continue
-
-                    op_profiles_list = test_picks[op]["profile"]
-                    if (
-                        args.profile != PROFILES_ALL
-                        and args.profile not in op_profiles_list
-                    ):
-                        # Skip this operator as not part of the profile chosen
-                        logger.debug(f"Skipping {op} as not part of {args.profile}")
-                        continue
-
-                    logger.debug(f"Copying and renaming {op}")
-                    framework_test_dir = copy_rename_framework_tests(
-                        args, op, test_picks
-                    )
-
-                    if args.convert_all_tests:
-                        logger.debug("Running and converting all framework tests")
-                        framework_tests = None  # Don't select any
-                    else:
-                        logger.debug("Running and converting selected framework tests")
-                        framework_tests = get_framework_tests_selection(
-                            args, op, test_picks, framework_test_dir
-                        )
-                    convert_tests(
-                        args,
-                        "positive",
-                        profile,
-                        op,
-                        framework_test_dir,
-                        root_output_dir,
-                        op_profiles_list,
-                        tests=framework_tests,
-                        trim_op_subdir=True,
-                    )
+                logger.error("Framework test support has been removed")
 
             # Operator unit tests
             if args.unit_tests in ["operator", "both"]:
                 logger.debug("Creating OPERATOR unit tests")
-                test_params_file = (
-                    args.param_json_dir
-                    / PROFILE_OPS_INFO[profile]["operator_test_params"]
-                )
+                if args.param_config is None:
+                    # Fall back to old method
+                    if profile_ext in PROFILE_OPS_INFO:
+                        config = PROFILE_OPS_INFO[profile_ext]["operator_test_params"]
+                        test_params_file = args.param_json_dir / config
+                    else:
+                        logger.error(
+                            "Extensions not supported in old conformance configs - skipping"
+                        )
+                        continue
+                else:
+                    test_params_file = args.param_config
+
                 try:
                     with open(test_params_file, "r") as fd:
                         test_params = json.load(fd)
@@ -922,6 +917,10 @@
                     # Create tests for all the operators
                     operators = list(test_params.keys())
 
+                print(
+                    f"Creating conformance tests for TOSA {profile_ext} profile/extension"
+                )
+
                 for op in operators:
                     logger.info(f"OPERATOR: {op}")
                     if op not in test_params:
@@ -930,30 +929,49 @@
                         )
                         continue
 
-                    op_profiles_list = test_params[op]["profile"]
-                    if (
-                        args.profile != PROFILES_ALL
-                        and args.profile not in op_profiles_list
-                    ):
-                        # Skip this operator as not part of the profile chosen
-                        logger.debug(f"Skipping {op} as not part of {args.profile}")
-                        continue
-
                     operator_group = test_params[op]["group"]
                     root_output_dir = args.output_dir / "operators"
-                    supports = (
-                        test_params[op]["support_for"]
-                        if "support_for" in test_params[op]
-                        else []
-                    )
-                    gen_filter = (
-                        test_params[op]["gen_filter"]
-                        if "gen_filter" in test_params[op]
-                        else None
-                    )
+                    supports = test_params[op].get("support_for", [])
+                    gen_filter = test_params[op].get("gen_filter", None)
+                    old_profile_info = test_params[op].get("profile", [])
 
                     # Iterate through the generation groups selecting tests from each
                     for gen_name, gen_dict in test_params[op]["generation"].items():
+                        supports_any = gen_dict.get("supports_any", [])
+                        supports_all = gen_dict.get("supports_all", [])
+
+                        # Fall back for old configs
+                        if not supports_all and not supports_any:
+                            if not old_profile_info:
+                                logger.error(
+                                    f"generator {gen_name} for {op} is missing supports_all/supports_any"
+                                )
+                                raise (GenConformanceError())
+                            else:
+                                supports_any = old_profile_info
+
+                        supported = supports_any + supports_all
+
+                        if profile_ext not in supported:
+                            logger.info(
+                                f"No match for profile/extension {profile_ext} for generation group {gen_name} - skipping"
+                            )
+                            continue
+
+                        if any(p in supported for p in profileExtDone):
+                            logger.info(
+                                f"Already used this generator {gen_name} before - skipping"
+                            )
+                            continue
+
+                        if profile_ext not in supports_any and not (
+                            len(supports_all) > 0
+                            and all(p in profileExtList for p in supports_all)
+                        ):
+                            logger.info(
+                                f"Profile/extension {profile_ext} is not in {supports_any} or the profiles/extensions chosen do not meet all the requirements of {supports_all} - skipping"
+                            )
+                            continue
 
                         if not in_version(args.test_version, gen_dict):
                             logger.warning(
@@ -993,16 +1011,23 @@
                                 selector_name = "default"
                         else:
                             selector_name = "default"
+
                         if selector_name not in test_params[op]["selection"]:
                             logger.error(
                                 f"Could not find {selector_name} in selection dict for {op}"
                             )
                             raise (GenConformanceError())
 
+                        supports = test_params[op].get("support_for", [])
+                        if test_params[op]["selection"][selector_name].get(
+                            "generator_select", False
+                        ):
+                            # Use the selector's test selection in the generator - supports is reset above per generation group so this cannot leak into later groups
+                            supports = supports + ["generator_select"]
                         op_build_dir = build_op_tests(
                             args,
                             test_type,
-                            profile,
+                            profile_ext,
                             op,
                             gen_name,
                             gen_dict["generator_args"],
@@ -1020,7 +1045,11 @@
                             if test_type in ["positive", "both"]:
                                 logger.info(f"Running and converting all {op} tests")
                                 generate_results(
-                                    args, profile, op, op_build_dir, supports=supports
+                                    args,
+                                    profile_ext,
+                                    op,
+                                    op_build_dir,
+                                    supports=supports,
                                 )
                             operator_test_list = None
                         else:
@@ -1053,7 +1082,7 @@
                                     tests_gen, tests_gen2 = tee(
                                         get_op_tests_selection(
                                             args,
-                                            profile,
+                                            profile_ext,
                                             op,
                                             op_build_dir,
                                             selection_config,
@@ -1062,7 +1091,7 @@
                                     )
                                 generate_results(
                                     args,
-                                    profile,
+                                    profile_ext,
                                     op,
                                     op_build_dir,
                                     supports=supports,
@@ -1075,7 +1104,7 @@
                                 operator_test_list.extend(
                                     get_op_tests_selection(
                                         args,
-                                        profile,
+                                        profile_ext,
                                         op,
                                         op_build_dir,
                                         selection_config,
@@ -1086,22 +1115,24 @@
                         tags = (
                             [gen_name] if gen_name != STANDARD_GENERATOR_GROUP else None
                         )
-
                         output_dir = convert_tests(
                             args,
                             test_type,
-                            profile,
+                            profile_ext,
                             op,
                             op_build_dir,
                             root_output_dir,
-                            op_profiles_list,
+                            supported,
                             supports=supports,
                             tests=operator_test_list,
                             group=operator_group,
                             tags=tags,
                         )
                         if not args.keep_large_files:
-                            check_op_tests(args, profile, op, output_dir)
+                            check_op_tests(args, profile_ext, op, output_dir)
+
+            profileExtDone.append(profile_ext)
+
     except GenConformanceError:
         return 1