blob: 904d90f09a59b035b9a15d153bb886867e0f13c6 [file] [log] [blame]
# Copyright (c) 2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import copy
import logging
# Named logger used throughout this module for selection diagnostics.
logger = logging.getLogger("tosa_verif_build_tests")
# Give the root logger a default handler (no-op if one is already configured)
# so warnings from this module are visible when run as a script.
logging.basicConfig()
class Test:
    """A single generated test, tracked for group and permute selection.

    Keeps the test's identifying string plus the parameters used to build it
    so that selection code can derive a permutation key (``setKey``) and a
    group key (``setGroupKey``) from them.
    """

    def __init__(
        self, opName, testStr, dtype, error, shapeList, argsDict, testOpName=None
    ):
        self.opName = opName
        self.testStr = testStr
        self.dtype = dtype
        self.error = error
        self.shapeList = shapeList
        self.argsDict = argsDict
        # Given test op name used for look up in TOSA_OP_LIST for "conv2d_1x1"
        # for example; falls back to opName when not supplied
        self.testOpName = opName if testOpName is None else testOpName
        self.key = None
        self.groupKey = None
        self.mark = False

    def __str__(self):
        return self.testStr

    def __lt__(self, other):
        # Order tests by their string form so selections can be sorted
        return self.testStr < str(other)

    def getArg(self, param):
        """Return a hashable value for ``param``, or None if it is unknown.

        "rank" is the length of the first shape, "dtype" is the dtype (a list
        of dtypes becomes a tuple), and "shape" is the string of the first
        shape unless argsDict provides its own "shape". Every other param is
        read from argsDict and converted to a single-line string.
        """
        if param == "rank":
            return len(self.shapeList[0])
        if param == "dtype":
            return tuple(self.dtype) if isinstance(self.dtype, list) else self.dtype
        if param == "shape" and "shape" not in self.argsDict:
            return str(self.shapeList[0])
        if param not in self.argsDict:
            return None
        # Stringify so the value is hashable, joining multi-line reprs with ","
        return ",".join(str(self.argsDict[param]).splitlines())

    def setKey(self, keyParams):
        """Compute, store and return this test's permutation key.

        Error tests are keyed purely on their error value; other tests on the
        tuple of ``getArg`` values for ``keyParams``.
        """
        self.key = (
            self.error
            if self.error is not None
            else tuple(self.getArg(param) for param in keyParams)
        )
        return self.key

    def getKey(self):
        return self.key

    def setGroupKey(self, groupParams):
        """Compute, store and return this test's group key.

        The key is built from "shape", "dtype" and every argsDict param, in
        sorted order, EXCEPT those listed in ``groupParams``. Tests differing
        only in group params (e.g. a test set number) therefore share a key.
        """
        candidates = sorted(["shape", "dtype", *self.argsDict])
        self.groupKey = tuple(
            self.getArg(param) for param in candidates if param not in groupParams
        )
        return self.groupKey

    def getGroupKey(self):
        return self.groupKey

    def inGroup(self, groupKey):
        return self.groupKey == groupKey

    def setMark(self):
        """Flag the test as important (used to keep its group together)."""
        self.mark = True

    def getMark(self):
        return self.mark

    def isError(self):
        return self.error is not None
def _get_selection_info_from_op(op, selectionCriteria, item, default):
# Get selection info from the op
if (
"selection" in op
and selectionCriteria in op["selection"]
and item in op["selection"][selectionCriteria]
):
return op["selection"][selectionCriteria][item]
else:
return default
def _get_tests_by_group(tests):
# Create simple structures to record the tests in groups
groups = []
group_tests = {}
for test in tests:
key = test.getGroupKey()
if key in group_tests:
group_tests[key].append(test)
else:
group_tests[key] = [test]
groups.append(key)
# Return list of test groups (group keys) and a dictionary with a list of tests
# associated with each group key
return groups, group_tests
def _get_specific_op_info(opName, opSelectionInfo, testOpName):
# Get the op specific section from the selection config
name = opName if opName in opSelectionInfo else testOpName
if name not in opSelectionInfo:
logger.info(f"No op entry found for {opName} in test selection config")
return {}
return opSelectionInfo[name]
class TestOpList:
    """All the tests for one op, bucketed by permutation key.

    ``add`` registers tests and their permutation/group keys; ``select`` then
    picks a subset of each permutation according to the op's entry in the
    selection config.
    """

    def __init__(self, opName, opSelectionInfo, selectionCriteria, testOpName):
        self.opName = opName
        self.testOpName = testOpName
        op = _get_specific_op_info(opName, opSelectionInfo, testOpName)

        # See verif/conformance/README.md for more information on
        # these selection arguments
        self.permuteArgs = _get_selection_info_from_op(
            op, selectionCriteria, "permutes", ["rank", "dtype"]
        )
        self.paramArgs = _get_selection_info_from_op(
            op, selectionCriteria, "full_params", []
        )
        self.specificArgs = _get_selection_info_from_op(
            op, selectionCriteria, "specifics", {}
        )
        self.groupArgs = _get_selection_info_from_op(
            op, selectionCriteria, "groups", ["s"]
        )
        # None means no per-permutation limit
        self.maximumPerPermute = _get_selection_info_from_op(
            op, selectionCriteria, "maximum", None
        )
        self.numErrorIfs = _get_selection_info_from_op(
            op, selectionCriteria, "num_errorifs", 1
        )
        self.selectAll = _get_selection_info_from_op(
            op, selectionCriteria, "all", False
        )

        # "maximum" defaults to None and `None > 1` raises TypeError, so make
        # sure it is set before comparing
        if (
            self.paramArgs
            and self.maximumPerPermute is not None
            and self.maximumPerPermute > 1
        ):
            logger.warning(f"Unsupported - selection params AND maximum for {opName}")

        self.tests = []
        self.testStrings = set()
        self.shapes = set()
        self.permutes = set()
        # permute key -> tests chosen for it by _init_select
        self.testsPerPermute = {}
        # permute key -> {full_params arg -> set of values seen so far}
        self.paramsPerPermute = {}
        # permute key -> private copy of the still-wanted "specifics" values
        self.specificsPerPermute = {}
        self.selectionDone = False

    def __len__(self):
        return len(self.tests)

    def add(self, test):
        """Add ``test`` and set up its permutation and group keys.

        A test whose string form has already been added is skipped (logged at
        debug level).
        """
        assert test.opName.startswith(self.opName)
        testStr = str(test)
        if testStr in self.testStrings:
            logger.debug(f"Skipping duplicate test: {testStr}")
            return
        self.tests.append(test)
        self.testStrings.add(testStr)
        self.shapes.add(test.getArg("shape"))

        # Work out the permutation key for this test
        permute = test.setKey(self.permuteArgs)
        # Set up the group key for the test (for pulling out groups during selection)
        test.setGroupKey(self.groupArgs)

        if permute not in self.permutes:
            # New permutation
            self.permutes.add(permute)
            # Set up area to record the selected tests
            self.testsPerPermute[permute] = []
            if self.paramArgs:
                # One set of seen values per "full_params" arg
                self.paramsPerPermute[permute] = {
                    param: set() for param in self.paramArgs
                }
            # Private copy of the wanted specific values; selection removes
            # each value once a test supplying it has been chosen
            self.specificsPerPermute[permute] = copy.deepcopy(self.specificArgs)

    def _match_specifics(self, test, permute):
        """Consume any still-wanted specific values ``test`` supplies.

        Each matched value is removed (so only one test is chosen for it) and
        the test is marked. Returns True if at least one value matched.
        """
        matched = False
        for param, values in self.specificsPerPermute[permute].items():
            arg = test.getArg(param)
            if arg in values:
                values.remove(arg)
                # Mark the test as special (and so its group is kept with it)
                test.setMark()
                matched = True
        return matched

    def _record_new_params(self, test, permute):
        """Record ``test``'s "full_params" values for ``permute``.

        Returns True if any value had not been seen before, or if no
        full_params are configured (every test is then a candidate).
        Always records every param, even after finding a new one.
        """
        if not self.paramArgs:
            return True
        seen = self.paramsPerPermute[permute]
        found_new = False
        for param in self.paramArgs:
            arg = test.getArg(param)
            if arg not in seen[param]:
                seen[param].add(arg)
                found_new = True
        return found_new

    def _init_select(self):
        """Choose the tests for every permutation (may only run once).

        Selection is first-come in ``self.tests`` order:
        - error tests: up to ``numErrorIfs`` per error key;
        - ``selectAll``: every non-error test;
        - otherwise: tests supplying a wanted specific value, plus tests with
          new full_params values (or any test if none are configured) up to
          ``maximumPerPermute``;
        - any unchosen test sharing a group key with a chosen *marked* test.
        """
        # Can only perform the selection process once as it alters the permute
        # information set at init
        assert not self.selectionDone

        # Count of non-specific tests added to each permute (not error)
        countPerPermute = {permute: 0 for permute in self.permutes}

        for test in self.tests:
            permute = test.getKey()
            selected = self.testsPerPermute[permute]

            if test.isError():
                # Error test, choose up to number of tests
                append = len(selected) < self.numErrorIfs
            elif self.selectAll:
                append = True
            else:
                append = self._match_specifics(test, permute)
                # Always called so the params of already-chosen tests are recorded
                possible_append = self._record_new_params(test, permute)
                if (
                    not append
                    and possible_append
                    and (
                        self.maximumPerPermute is None
                        or countPerPermute[permute] < self.maximumPerPermute
                    )
                ):
                    # Not selected but could be added and there is space left
                    append = True
                    countPerPermute[permute] += 1

            if not append:
                # Keep together any tests that form a group with a marked test
                key = test.getGroupKey()
                for chosen in selected:
                    if chosen.getGroupKey() == key and chosen.getMark():
                        test.setMark()
                        append = True
                        break

            if append:
                selected.append(test)

        self.selectionDone = True

    def select(self, rng=None):
        """Return the selected tests for this op, grouped by permutation.

        On the first call the tests are optionally shuffled with ``rng``
        (selection is first-come, so this changes which tests are picked) and
        the selection is made; later calls return the same selection.
        """
        if not self.selectionDone:
            if rng is not None:
                rng.shuffle(self.tests)
            self._init_select()

        # Now create the full list of selected tests per permute
        selection = []
        for tests in self.testsPerPermute.values():
            selection.extend(tests)
        return selection

    def all(self):
        """Return every added test, unselected (the internal list itself)."""
        return self.tests
class TestList:
    """Every generated test, bucketed into one TestOpList per operator."""

    def __init__(self, opSelectionInfo, selectionCriteria="default"):
        # opName -> TestOpList holding that op's tests
        self.opLists = {}
        self.opSelectionInfo = opSelectionInfo
        self.selectionCriteria = selectionCriteria

    def __len__(self):
        """Total number of tests across all operators."""
        return sum(len(opList) for opList in self.opLists.values())

    def add(self, test):
        """Add ``test`` to its op's list, creating that list on first use."""
        opList = self.opLists.get(test.opName)
        if opList is None:
            opList = TestOpList(
                test.opName,
                self.opSelectionInfo,
                self.selectionCriteria,
                test.testOpName,
            )
            self.opLists[test.opName] = opList
        opList.add(test)

    def _get_tests(self, selectMode, rng):
        """Gather tests from every op list.

        In select mode each op's selection is used and the result is sorted;
        otherwise all tests are returned in op/insertion order.
        """
        if not selectMode:
            return [test for opList in self.opLists.values() for test in opList.all()]
        chosen = [
            test
            for opList in self.opLists.values()
            for test in opList.select(rng=rng)
        ]
        return sorted(chosen)

    def select(self, rng=None):
        return self._get_tests(True, rng)

    def all(self):
        return self._get_tests(False, None)