Improve test selection before test generation

Add test list output to tosa_verif_build_tests and test list
capture to file for tosa_verif_conformance_generator

Improve PAD & CONV2D test coverage for tosa-mi conformance

Change output to use logging so info messages are kept out of test lists
Tweak verbosity levels of tosa_verif_conformance_generator

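For example, to preview which PAD tests would be selected without
generating them (an illustrative invocation; the config path is
hypothetical):

  tosa_verif_build_tests --filter pad \
      --test-selection-config selection.json --list-tests
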
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ic29da5776b02e9ac610db6ee89d0ebfb4994e055
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index c596645..253e8ee 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1,8 +1,8 @@
 # Copyright (c) 2021-2024, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
 import itertools
+import logging
 import math
-import warnings
 
 import generator.tosa_utils as gtu
 import numpy as np
@@ -16,6 +16,9 @@
 # DTypeNames, DType, Op and ResizeMode are convenience variables to the
 # flatc-generated types that should be enums, but aren't
 
+logging.basicConfig()
+logger = logging.getLogger("tosa_verif_build_tests")
+
 
 class TosaQuantGen:
     """QuantizedInfo random generator helper functions.
@@ -131,8 +134,9 @@
             shift = shift + 1
 
         shift = (-shift) + scaleBits
-        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
-        #   scaleFp, scaleBits, m, multiplier, shift))
+        logger.debug(
+            f"computeMultiplierAndShift: scalefp={scaleFp} scaleBits={scaleBits} m={m} mult={multiplier} shift={shift}"
+        )
 
         # Adjust multiplier such that shift is in allowed value range.
         if shift == 0:
@@ -690,8 +694,9 @@
                 # Invalid data range from low to high created due to user
                 # constraints revert to using internal ranges as they are
                 # known to work
-                msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
-                warnings.warn(msg)
+                logger.info(
+                    f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
+                )
                 data_range = (low_val, high_val)
             return data_range
         return None
@@ -1856,7 +1861,7 @@
                             if "shape" in args_dict
                             else ""
                         )
-                        print(
+                        logger.info(
                             f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
                         )
                         continue
@@ -2503,7 +2508,7 @@
                 arg_list.append((name, args_dict))
 
         if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
-            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")
+            logger.info(f"No ErrorIf test created for input shape: {shapeList[0]}")
 
         arg_list = TosaArgGen._add_data_generators(
             testGen,
@@ -2683,7 +2688,9 @@
                             remainder_w = partial_w % s[1]
                             output_h = partial_h // s[0] + 1
                             output_w = partial_w // s[1] + 1
-                            # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
+                            logger.debug(
+                                f"agPooling: {shape} remainder=({remainder_h}, {remainder_w}) output=({output_h}, {output_w})"
+                            )
                             if (
                                 # the parameters must produce integer exact output
                                 error_name != ErrorIf.PoolingOutputShapeNonInteger
@@ -2920,7 +2927,9 @@
                             # Cap the scaling at 2^15 - 1 for scale16
                             scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
 
-                        # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
+                        logger.debug(
+                            f"agRescale: {out_type_width} {in_type_width} -> {scale_arr}"
+                        )
 
                         multiplier_arr = np.int32(np.zeros(shape=[nc]))
                         shift_arr = np.int32(np.zeros(shape=[nc]))
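Note: each generator module registers the same "tosa_verif_build_tests"
logger, so a single setLevel() call in the tool's main() controls output
from all of them. A minimal sketch of the pattern (not part of the patch):

    import logging

    logging.basicConfig()
    logger = logging.getLogger("tosa_verif_build_tests")

    # One call, e.g. from main(), adjusts every module sharing this name:
    logging.getLogger("tosa_verif_build_tests").setLevel(logging.DEBUG)
    logger.debug("now visible at -vv")
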
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 7a4d0d6..3972edd 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1,5 +1,6 @@
 # Copyright (c) 2021-2024, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
+import logging
 import math
 
 import numpy as np
@@ -11,6 +12,9 @@
 from tosa.Op import Op
 from tosa.ResizeMode import ResizeMode
 
+logging.basicConfig()
+logger = logging.getLogger("tosa_verif_build_tests")
+
 
 class ErrorIf(object):
     MaxDimExceeded = "MaxDimExceeded"
@@ -386,12 +390,12 @@
             if expected_result and error_result:
                 serializer.setExpectedReturnCode(2, True, desc=error_reason)
             elif error_result:  # and not expected_result
-                print(
+                logger.error(
                     f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                     f" Expected: {error_name}, Got: {validator_name}"
                 )
             elif not expected_result:  # and not error_result
-                print(
+                logger.error(
                     f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                     f" Expected: {error_name}"
                 )
@@ -401,7 +405,7 @@
                     if k != "op":
                         if k.endswith("dtype"):
                             v = valueToName(DType, v)
-                        print(f"  {k} = {v}")
+                        logger.error(f"  {k} = {v}")
 
         return overall_result
 
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 0f68999..e7704f1 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2020-2024, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
 import json
+import logging
 import os
 from copy import deepcopy
 from datetime import datetime
@@ -27,6 +28,9 @@
 // AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests
 """
 
+logging.basicConfig()
+logger = logging.getLogger("tosa_verif_build_tests")
+
 
 class TosaTestGen:
     # Maximum rank of tensor supported by test generator.
@@ -2134,6 +2138,7 @@
         double_round = args_dict["double_round"]
         per_channel = args_dict["per_channel"]
         shift_arr = args_dict["shift"]
+        multiplier_arr = args_dict["multiplier"]
 
         result_tensor = OutputShaper.typeConversionOp(
             self.ser, self.rng, val, out_dtype, error_name
@@ -2203,7 +2208,9 @@
             min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
             max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
 
-        # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
+        logger.debug(
+            f"build_rescale: multiplier={multiplier_arr} shift={shift_arr} inzp={input_zp} outzp={output_zp}"
+        )
         if scale32 and error_name is None:
             # Make sure random values are within apply_scale_32 specification
             # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
@@ -2907,7 +2914,9 @@
             cleanRankFilter = filterDict["rankFilter"]
             cleanDtypeFilter = filterDict["dtypeFilter"]
             cleanShapeFilter = filterDict["shapeFilter"]
-            # print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
+            logger.debug(
+                f"genOpTestList: Error={error_name}, Filters S={cleanShapeFilter}, R={cleanRankFilter}, T={cleanDtypeFilter}"
+            )
 
             for r in cleanRankFilter:
                 for t in cleanDtypeFilter:
@@ -2981,8 +2990,7 @@
         except KeyError:
             raise Exception("Cannot find op with name {}".format(opName))
 
-        if self.args.verbose:
-            print(f"Creating {testStr}")
+        logger.info(f"Creating {testStr}")
 
         # Create a serializer
         self.createSerializer(opName, testStr)
@@ -3062,7 +3070,7 @@
             self.serialize("test", tensMeta)
         else:
             # The test is not valid
-            print(f"Invalid ERROR_IF test created: {opName} {testStr}")
+            logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
 
     def createDynamicOpLists(self):
 
@@ -3084,6 +3092,7 @@
             self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
             self.TOSA_OP_LIST[testName]["filter"] = k
             self.TOSA_OP_LIST[testName]["template"] = False
+            self.TOSA_OP_LIST[testName]["real_name"] = "conv2d"
 
             testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
             self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
@@ -3091,6 +3100,7 @@
             ].copy()
             self.TOSA_OP_LIST[testName]["filter"] = k
             self.TOSA_OP_LIST[testName]["template"] = False
+            self.TOSA_OP_LIST[testName]["real_name"] = "depthwise_conv2d"
 
             testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
             self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
@@ -3098,12 +3108,14 @@
             ].copy()
             self.TOSA_OP_LIST[testName]["filter"] = k
             self.TOSA_OP_LIST[testName]["template"] = False
+            self.TOSA_OP_LIST[testName]["real_name"] = "transpose_conv2d"
 
         for k in KERNELS_3D:
             testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
             self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
             self.TOSA_OP_LIST[testName]["filter"] = k
             self.TOSA_OP_LIST[testName]["template"] = False
+            self.TOSA_OP_LIST[testName]["real_name"] = "conv3d"
 
         # Delete any templates after having created any dynamic ops
         # This is a two-pass operation because it's bad practice to delete
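Note: the new "real_name" entries let later stages map the dynamically
created per-kernel ops back to their base op when looking up selection
config. A sketch of the intended lookup (assuming a 3x3 kernel entry; it
mirrors the use in tosa_verif_build_tests.py below):

    op = ttg.TOSA_OP_LIST["conv2d_3x3"]        # dynamic op created above
    name = op.get("real_name", "conv2d_3x3")   # -> "conv2d"
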
diff --git a/verif/generator/tosa_test_select.py b/verif/generator/tosa_test_select.py
new file mode 100644
index 0000000..5a13178
--- /dev/null
+++ b/verif/generator/tosa_test_select.py
@@ -0,0 +1,348 @@
+# Copyright (c) 2024, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import copy
+import logging
+
+logging.basicConfig()
+logger = logging.getLogger("tosa_verif_build_tests")
+
+
+class Test:
+    """Test container to allow group and permute selection."""
+
+    def __init__(
+        self, opName, testStr, dtype, error, shapeList, argsDict, testOpName=None
+    ):
+        self.opName = opName
+        self.testStr = testStr
+        self.dtype = dtype
+        self.error = error
+        self.shapeList = shapeList
+        self.argsDict = argsDict
+        # Test op name used to look up the op in TOSA_OP_LIST, e.g. "conv2d_1x1"
+        self.testOpName = testOpName if testOpName is not None else opName
+
+        self.key = None
+        self.groupKey = None
+        self.mark = False
+
+    def __str__(self):
+        return self.testStr
+
+    def __lt__(self, other):
+        return self.testStr < str(other)
+
+    def getArg(self, param):
+        # Get parameter values (arguments) for this test
+        if param == "rank":
+            return len(self.shapeList[0])
+        elif param == "dtype":
+            if isinstance(self.dtype, list):
+                return tuple(self.dtype)
+            return self.dtype
+        elif param == "shape" and "shape" not in self.argsDict:
+            return str(self.shapeList[0])
+
+        if param in self.argsDict:
+            # Turn other args into a hashable string without newlines
+            val = str(self.argsDict[param])
+            return ",".join(val.splitlines())
+        else:
+            return None
+
+    def setKey(self, keyParams):
+        if self.error is None:
+            # Create the main key based on primary parameters
+            key = [self.getArg(param) for param in keyParams]
+            self.key = tuple(key)
+        else:
+            # Use the error as the key
+            self.key = self.error
+        return self.key
+
+    def getKey(self):
+        return self.key
+
+    def setGroupKey(self, groupParams):
+        # Create the group key from the arguments that do NOT define the group,
+        # so this test will match other tests that differ only in the group
+        # arguments (such as the test set number)
+        paramsList = sorted(["shape", "dtype"] + list(self.argsDict.keys()))
+        key = []
+        for param in paramsList:
+            if param in groupParams:
+                continue
+            key.append(self.getArg(param))
+        self.groupKey = tuple(key)
+        return self.groupKey
+
+    def getGroupKey(self):
+        return self.groupKey
+
+    def inGroup(self, groupKey):
+        return self.groupKey == groupKey
+
+    def setMark(self):
+        # Marks the test as important
+        self.mark = True
+
+    def getMark(self):
+        return self.mark
+
+    def isError(self):
+        return self.error is not None
+
+
+def _get_selection_info_from_op(op, selectionCriteria, item, default):
+    # Get selection info from the op
+    if (
+        "selection" in op
+        and selectionCriteria in op["selection"]
+        and item in op["selection"][selectionCriteria]
+    ):
+        return op["selection"][selectionCriteria][item]
+    else:
+        return default
+
+
+def _get_tests_by_group(tests):
+    # Create simple structures to record the tests in groups
+    groups = []
+    group_tests = {}
+
+    for test in tests:
+        key = test.getGroupKey()
+        if key in group_tests:
+            group_tests[key].append(test)
+        else:
+            group_tests[key] = [test]
+            groups.append(key)
+
+    # Return list of test groups (group keys) and a dictionary with a list of tests
+    # associated with each group key
+    return groups, group_tests
+
+
+def _get_specific_op_info(opName, opSelectionInfo, testOpName):
+    # Get the op specific section from the selection config
+    name = opName if opName in opSelectionInfo else testOpName
+    if name not in opSelectionInfo:
+        logger.info(f"No op entry found for {opName} in test selection config")
+        return {}
+    return opSelectionInfo[name]
+
+
+class TestOpList:
+    """All the tests for one op grouped by permutations."""
+
+    def __init__(self, opName, opSelectionInfo, selectionCriteria, testOpName):
+        self.opName = opName
+        self.testOpName = testOpName
+        op = _get_specific_op_info(opName, opSelectionInfo, testOpName)
+
+        # See verif/conformance/README.md for more information on
+        # these selection arguments
+        self.permuteArgs = _get_selection_info_from_op(
+            op, selectionCriteria, "permutes", ["rank", "dtype"]
+        )
+        self.paramArgs = _get_selection_info_from_op(
+            op, selectionCriteria, "full_params", []
+        )
+        self.specificArgs = _get_selection_info_from_op(
+            op, selectionCriteria, "specifics", {}
+        )
+        self.groupArgs = _get_selection_info_from_op(
+            op, selectionCriteria, "groups", ["s"]
+        )
+        self.maximumPerPermute = _get_selection_info_from_op(
+            op, selectionCriteria, "maximum", None
+        )
+        self.numErrorIfs = _get_selection_info_from_op(
+            op, selectionCriteria, "num_errorifs", 1
+        )
+        self.selectAll = _get_selection_info_from_op(
+            op, selectionCriteria, "all", False
+        )
+
+        # maximumPerPermute may be None (unset), which never exceeds one
+        if self.paramArgs and (self.maximumPerPermute or 0) > 1:
+            logger.warning(f"Unsupported - selection params AND maximum for {opName}")
+
+        self.tests = []
+        self.testStrings = set()
+        self.shapes = set()
+
+        self.permutes = set()
+        self.testsPerPermute = {}
+        self.paramsPerPermute = {}
+        self.specificsPerPermute = {}
+
+        self.selectionDone = False
+
+    def __len__(self):
+        return len(self.tests)
+
+    def add(self, test):
+        # Add a test to this op group and set up the permutations/group for it
+        assert test.opName.startswith(self.opName)
+        if str(test) in self.testStrings:
+            logger.info(f"Skipping duplicate test: {str(test)}")
+            return
+
+        self.tests.append(test)
+        self.testStrings.add(str(test))
+
+        self.shapes.add(test.getArg("shape"))
+
+        # Work out the permutation key for this test
+        permute = test.setKey(self.permuteArgs)
+        # Set up the group key for the test (for pulling out groups during selection)
+        test.setGroupKey(self.groupArgs)
+
+        if permute not in self.permutes:
+            # New permutation
+            self.permutes.add(permute)
+            # Set up area to record the selected tests
+            self.testsPerPermute[permute] = []
+            if self.paramArgs:
+                # Set up area to record the unique test params found
+                self.paramsPerPermute[permute] = {}
+                for param in self.paramArgs:
+                    self.paramsPerPermute[permute][param] = set()
+            # Set up copy of the specific test args for selecting these
+            self.specificsPerPermute[permute] = copy.deepcopy(self.specificArgs)
+
+    def _init_select(self):
+        # Can only perform the selection process once as it alters the permute
+        # information set at init
+        assert not self.selectionDone
+
+        # Count of non-specific, non-error tests added to each permute
+        if not self.selectAll:
+            countPerPermute = {permute: 0 for permute in self.permutes}
+
+        # Go through each test looking for permutes, unique params & specifics
+        for test in self.tests:
+            permute = test.getKey()
+            append = False
+            possible_append = False
+
+            if test.isError():
+                # Error test, choose up to number of tests
+                if len(self.testsPerPermute[permute]) < self.numErrorIfs:
+                    append = True
+            else:
+                if self.selectAll:
+                    append = True
+                else:
+                    # See if this is a specific test to add
+                    for param, values in self.specificsPerPermute[permute].items():
+                        arg = test.getArg(param)
+                        # Iterate over a copy of the values, so we can remove them from the original
+                        if arg in values.copy():
+                            # Found a match, remove it, so we don't look for it later
+                            values.remove(arg)
+                            # Mark the test as special (and so shouldn't be removed)
+                            test.setMark()
+                            append = True
+
+                    if self.paramArgs:
+                        # See if this test contains any new params we should keep
+                        # Perform this check even if we have already selected the test
+                        # so we can record the params found
+                        for param in self.paramArgs:
+                            arg = test.getArg(param)
+                            if arg not in self.paramsPerPermute[permute][param]:
+                                # We have found a new value for this arg, record it
+                                self.paramsPerPermute[permute][param].add(arg)
+                                possible_append = True
+                    else:
+                        # No params set, so possible test to add up to maximum
+                        possible_append = True
+
+                    if (not append and possible_append) and (
+                        self.maximumPerPermute is None
+                        or countPerPermute[permute] < self.maximumPerPermute
+                    ):
+                        # Not selected but could be added and we have space left if
+                        # a maximum is set.
+                        append = True
+                        countPerPermute[permute] += 1
+
+            # Check for grouping with chosen tests
+            if not append:
+                # We will keep any tests together than form a group
+                key = test.getGroupKey()
+                for t in self.testsPerPermute[permute]:
+                    if t.getGroupKey() == key:
+                        if t.getMark():
+                            test.setMark()
+                        append = True
+
+            if append:
+                self.testsPerPermute[permute].append(test)
+
+        self.selectionDone = True
+
+    def select(self, rng=None):
+        # Create selection of tests with optional shuffle
+        if not self.selectionDone:
+            if rng:
+                rng.shuffle(self.tests)
+
+            self._init_select()
+
+        # Now create the full list of selected tests per permute
+        selection = []
+
+        for tests in self.testsPerPermute.values():
+            selection.extend(tests)
+
+        return selection
+
+    def all(self):
+        # Un-selected list of tests - i.e. all of them
+        return self.tests
+
+
+class TestList:
+    """List of all tests grouped by operator."""
+
+    def __init__(self, opSelectionInfo, selectionCriteria="default"):
+        self.opLists = {}
+        self.opSelectionInfo = opSelectionInfo
+        self.selectionCriteria = selectionCriteria
+
+    def __len__(self):
+        return sum(len(opList) for opList in self.opLists.values())
+
+    def add(self, test):
+        if test.opName not in self.opLists:
+            self.opLists[test.opName] = TestOpList(
+                test.opName,
+                self.opSelectionInfo,
+                self.selectionCriteria,
+                test.testOpName,
+            )
+        self.opLists[test.opName].add(test)
+
+    def _get_tests(self, selectMode, rng):
+        selection = []
+
+        for opList in self.opLists.values():
+            if selectMode:
+                tests = opList.select(rng=rng)
+            else:
+                tests = opList.all()
+            selection.extend(tests)
+
+        selection = sorted(selection)
+        return selection
+
+    def select(self, rng=None):
+        return self._get_tests(True, rng)
+
+    def all(self):
+        return self._get_tests(False, None)
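
Note: the selection config read by _get_selection_info_from_op() is keyed
by op name, then "selection", then the criteria name. An illustrative
sketch of one entry and of driving the module (the values are
hypothetical; the calls mirror tosa_verif_build_tests.py below):

    # A selection config entry as seen after json.load()
    selectionCfg = {
        "conv2d": {
            "selection": {
                "default": {
                    "permutes": ["rank", "dtype"],  # params defining each permutation
                    "full_params": ["pad"],         # keep every unique value of these
                    "specifics": {"shape": ["[1, 5, 5, 3]"]},  # always keep matches
                    "groups": ["s"],                # params grouping tests (test set)
                    "maximum": 2,                   # cap of tests per permutation
                    "num_errorifs": 1,              # ERROR_IF tests kept per error
                }
            }
        }
    }

    testList = tts.TestList(selectionCfg, selectionCriteria="default")
    # ... add each generated test via testList.add(tts.Test(...)) ...
    selected = testList.select(rng)  # shuffle, then select per-op permutations
    everything = testList.all()      # bypass selection entirely
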
diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py
index 8012d93..c32993a 100644
--- a/verif/generator/tosa_verif_build_tests.py
+++ b/verif/generator/tosa_verif_build_tests.py
@@ -1,17 +1,23 @@
-# Copyright (c) 2020-2023, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
 import argparse
+import json
+import logging
 import re
 import sys
 from pathlib import Path
 
 import conformance.model_files as cmf
+import generator.tosa_test_select as tts
 from generator.tosa_test_gen import TosaTestGen
 from serializer.tosa_serializer import dtype_str_to_val
 from serializer.tosa_serializer import DTypeNames
 
 OPTION_FP_VALUES_RANGE = "--fp-values-range"
 
+logging.basicConfig()
+logger = logging.getLogger("tosa_verif_build_tests")
+
 
 # Used for parsing a comma-separated list of integers/floats in a string
 # to an actual list of integers/floats with special case max
@@ -58,6 +64,7 @@
 
     parser = argparse.ArgumentParser()
 
+    filter_group = parser.add_argument_group("test filter options")
     ops_group = parser.add_argument_group("operator options")
     tens_group = parser.add_argument_group("tensor options")
 
@@ -73,7 +80,7 @@
         help="Random seed for test generation",
     )
 
-    parser.add_argument(
+    filter_group.add_argument(
         "--filter",
         dest="filter",
         default="",
@@ -82,7 +89,12 @@
     )
 
     parser.add_argument(
-        "-v", "--verbose", dest="verbose", action="count", help="Verbose operation"
+        "-v",
+        "--verbose",
+        dest="verbose",
+        action="count",
+        default=0,
+        help="Verbose operation",
     )
 
     parser.add_argument(
@@ -226,7 +238,7 @@
         help="Allow constant input tensors for concat operator",
     )
 
-    parser.add_argument(
+    filter_group.add_argument(
         "--test-type",
         dest="test_type",
         choices=["positive", "negative", "both"],
@@ -235,6 +247,26 @@
         help="type of tests produced, positive, negative, or both",
     )
 
+    filter_group.add_argument(
+        "--test-selection-config",
+        dest="selection_config",
+        type=Path,
+        help="enables test selection, this is the path to the JSON test selection config file, will use the default selection specified for each op unless --selection-criteria is supplied",
+    )
+
+    filter_group.add_argument(
+        "--test-selection-criteria",
+        dest="selection_criteria",
+        help="enables test selection, this is the selection criteria to use from the selection config",
+    )
+
+    parser.add_argument(
+        "--list-tests",
+        dest="list_tests",
+        action="store_true",
+        help="lists the tests that will be generated and then exits",
+    )
+
     ops_group.add_argument(
         "--allow-pooling-and-conv-oversizes",
         dest="oversize",
@@ -281,6 +313,10 @@
 
     args = parseArgs(argv)
 
+    loglevels = (logging.WARNING, logging.INFO, logging.DEBUG)
+    loglevel = loglevels[min(args.verbose, len(loglevels) - 1)]
+    logger.setLevel(loglevel)
+
     if not args.lazy_data_gen:
         if args.generate_lib_path is None:
             args.generate_lib_path = cmf.find_tosa_file(
@@ -290,55 +326,98 @@
             print(
                 f"Argument error: Generate library (--generate-lib-path) not found - {str(args.generate_lib_path)}"
             )
-            exit(2)
+            return 2
 
     ttg = TosaTestGen(args)
 
+    # Determine if test selection mode is enabled or not
+    selectionMode = (
+        args.selection_config is not None or args.selection_criteria is not None
+    )
+    selectionCriteria = (
+        "default" if args.selection_criteria is None else args.selection_criteria
+    )
+    if args.selection_config is not None:
+        # Try loading the selection config
+        if not args.selection_config.is_file():
+            print(
+                f"Argument error: Test selection config (--test-selection-config) not found - {str(args.selection_config)}"
+            )
+            return 2
+        with args.selection_config.open("r") as fd:
+            selectionCfg = json.load(fd)
+    else:
+        # Fall back to using anything defined in the TosaTestGen list;
+        # by default this means only selecting tests using a permutation
+        # of rank by dtype for each op
+        selectionCfg = ttg.TOSA_OP_LIST
+
     if args.test_type == "both":
         testType = ["positive", "negative"]
     else:
         testType = [args.test_type]
+
     results = []
     for test_type in testType:
-        testList = []
+        testList = tts.TestList(selectionCfg, selectionCriteria=selectionCriteria)
         try:
             for opName in ttg.TOSA_OP_LIST:
                 if re.match(args.filter + ".*", opName):
-                    testList.extend(
-                        ttg.genOpTestList(
-                            opName,
-                            shapeFilter=args.target_shapes,
-                            rankFilter=args.target_ranks,
-                            dtypeFilter=args.target_dtypes,
-                            testType=test_type,
-                        )
+                    tests = ttg.genOpTestList(
+                        opName,
+                        shapeFilter=args.target_shapes,
+                        rankFilter=args.target_ranks,
+                        dtypeFilter=args.target_dtypes,
+                        testType=test_type,
                     )
+                    for testOpName, testStr, dtype, error, shapeList, argsDict in tests:
+                        if "real_name" in ttg.TOSA_OP_LIST[testOpName]:
+                            name = ttg.TOSA_OP_LIST[testOpName]["real_name"]
+                        else:
+                            name = testOpName
+                        test = tts.Test(
+                            name, testStr, dtype, error, shapeList, argsDict, testOpName
+                        )
+                        testList.add(test)
         except Exception as e:
-            print(f"INTERNAL ERROR: Failure generating test lists for {opName}")
+            logger.error(f"INTERNAL ERROR: Failure generating test lists for {opName}")
             raise e
 
-        print("{} matching {} tests".format(len(testList), test_type))
+        if not selectionMode:
+            # Allow all tests to be selected
+            tests = testList.all()
+        else:
+            # Use the random number generator to shuffle the test list
+            # and select the per op tests from it
+            tests = testList.select(ttg.rng)
 
-        testStrings = []
+        if args.list_tests:
+            for test in tests:
+                print(test)
+            continue
+
+        print(f"{len(tests)} matching {test_type} tests")
+
         try:
-            for opName, testStr, dtype, error, shapeList, argsDict in testList:
-                # Check for and skip duplicate tests
-                if testStr in testStrings:
-                    print(f"Skipping duplicate test: {testStr}")
-                    continue
-                else:
-                    testStrings.append(testStr)
-
+            for test in tests:
+                opName = test.testOpName
                 results.append(
                     ttg.serializeTest(
-                        opName, testStr, dtype, error, shapeList, argsDict
+                        opName,
+                        str(test),
+                        test.dtype,
+                        test.error,
+                        test.shapeList,
+                        test.argsDict,
                     )
                 )
         except Exception as e:
-            print(f"INTERNAL ERROR: Failure creating test output for {opName}")
+            logger.error(f"INTERNAL ERROR: Failure creating test output for {opName}")
             raise e
 
-    print(f"Done creating {len(results)} tests")
+    if not args.list_tests:
+        print(f"Done creating {len(results)} tests")
+    return 0
 
 
 if __name__ == "__main__":
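
Note on verbosity: the -v count now maps onto the shared
"tosa_verif_build_tests" logger via the loglevels tuple above - no flag
keeps the default WARNING level (errors and warnings only), -v enables
INFO (for example the "Creating <test>" lines), and -vv or more enables
DEBUG. An illustrative invocation:

    tosa_verif_build_tests -vv --filter pad   # DEBUG-level output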