Jeremy Johnson | af09018 | 2024-02-13 18:25:39 +0000 | [diff] [blame] | 1 | # Copyright (c) 2024, ARM Limited. |
| 2 | # SPDX-License-Identifier: Apache-2.0 |
| 3 | import copy |
| 4 | import logging |
| 5 | |
# Attach a default stderr handler to the root logger (the stdlib makes this
# a no-op if the root logger already has handlers)
logging.basicConfig()
# Logger used for all messages emitted by the test selection code below
logger = logging.getLogger("tosa_verif_build_tests")
| 8 | |
| 9 | |
class Test:
    """Test container to allow group and permute selection.

    Holds what identifies a single generated test (op name, test string,
    dtype(s), optional error-if, input shapes and argument dict) together
    with the selection state computed later: the permute key, the group key
    and the "mark" flag used for tests that must be kept.
    """

    def __init__(
        self, opName, testStr, dtype, error, shapeList, argsDict, testOpName=None
    ):
        self.opName = opName
        self.testStr = testStr
        self.dtype = dtype
        self.error = error
        self.shapeList = shapeList
        self.argsDict = argsDict
        # Op name used for the TOSA_OP_LIST look up (e.g. "conv2d_1x1");
        # defaults to opName when not given
        self.testOpName = opName if testOpName is None else testOpName

        # Selection state, filled in by setKey / setGroupKey / setMark
        self.key = None
        self.groupKey = None
        self.mark = False

    def __str__(self):
        return self.testStr

    def __lt__(self, other):
        # Tests are ordered by their test string
        return self.testStr < str(other)

    def getArg(self, param):
        """Return a hashable value of ``param`` for this test, or None.

        * "rank"  - number of dimensions of the first shape in shapeList
        * "dtype" - the dtype; a list of dtypes is returned as a tuple
        * "shape" - string form of the first shape, unless argsDict has its
          own "shape" entry
        * otherwise the argsDict entry converted to a string with any
          newlines replaced by commas, or None when it is absent
        """
        if param == "rank":
            return len(self.shapeList[0])
        if param == "dtype":
            return tuple(self.dtype) if isinstance(self.dtype, list) else self.dtype
        if param == "shape" and "shape" not in self.argsDict:
            return str(self.shapeList[0])
        if param not in self.argsDict:
            return None
        return ",".join(str(self.argsDict[param]).splitlines())

    def setKey(self, keyParams):
        """Compute, store and return the permute key.

        Error tests are keyed by their error; other tests by the tuple of
        getArg values for each name in ``keyParams``.
        """
        if self.error is not None:
            self.key = self.error
        else:
            self.key = tuple(map(self.getArg, keyParams))
        return self.key

    def getKey(self):
        return self.key

    def setGroupKey(self, groupParams):
        """Compute, store and return the group key.

        The key holds the values of every parameter that is NOT a group
        parameter ("shape", "dtype" and all argsDict keys, sorted by name), so
        tests differing only in group parameters (like a test set number)
        share the same group key.
        """
        allParams = sorted(["shape", "dtype", *self.argsDict])
        self.groupKey = tuple(
            self.getArg(name) for name in allParams if name not in groupParams
        )
        return self.groupKey

    def getGroupKey(self):
        return self.groupKey

    def inGroup(self, groupKey):
        return self.groupKey == groupKey

    def setMark(self):
        # Flag the test as important so selection keeps it
        self.mark = True

    def getMark(self):
        return self.mark

    def isError(self):
        return self.error is not None
| 94 | |
| 95 | |
| 96 | def _get_selection_info_from_op(op, selectionCriteria, item, default): |
| 97 | # Get selection info from the op |
| 98 | if ( |
| 99 | "selection" in op |
| 100 | and selectionCriteria in op["selection"] |
| 101 | and item in op["selection"][selectionCriteria] |
| 102 | ): |
| 103 | return op["selection"][selectionCriteria][item] |
| 104 | else: |
| 105 | return default |
| 106 | |
| 107 | |
| 108 | def _get_tests_by_group(tests): |
| 109 | # Create simple structures to record the tests in groups |
| 110 | groups = [] |
| 111 | group_tests = {} |
| 112 | |
| 113 | for test in tests: |
| 114 | key = test.getGroupKey() |
| 115 | if key in group_tests: |
| 116 | group_tests[key].append(test) |
| 117 | else: |
| 118 | group_tests[key] = [test] |
| 119 | groups.append(key) |
| 120 | |
| 121 | # Return list of test groups (group keys) and a dictionary with a list of tests |
| 122 | # associated with each group key |
| 123 | return groups, group_tests |
| 124 | |
| 125 | |
| 126 | def _get_specific_op_info(opName, opSelectionInfo, testOpName): |
| 127 | # Get the op specific section from the selection config |
| 128 | name = opName if opName in opSelectionInfo else testOpName |
| 129 | if name not in opSelectionInfo: |
| 130 | logger.info(f"No op entry found for {opName} in test selection config") |
| 131 | return {} |
| 132 | return opSelectionInfo[name] |
| 133 | |
| 134 | |
class TestOpList:
    """All the tests for one op grouped by permutations.

    Selection options are read from the op's entry in the selection config
    for the chosen selection criteria (see verif/conformance/README.md):

    * ``permutes``     - args forming the permute key (default ["rank", "dtype"])
    * ``full_params``  - args whose new values (per permute) make a test
      eligible for selection (default [])
    * ``specifics``    - arg values for which the first matching test in each
      permute is always selected and marked (default {})
    * ``groups``       - args left out of the group key, so tests differing
      only in these are kept alongside a selected, marked test (default ["s"])
    * ``maximum``      - cap on non-specific tests selected per permute
      (default None, meaning no cap)
    * ``num_errorifs`` - number of tests selected per error-if (default 1)
    * ``all``          - select every non-error test (default False)
    """

    def __init__(self, opName, opSelectionInfo, selectionCriteria, testOpName):
        self.opName = opName
        self.testOpName = testOpName
        op = _get_specific_op_info(opName, opSelectionInfo, testOpName)

        # See verif/conformance/README.md for more information on
        # these selection arguments
        self.permuteArgs = _get_selection_info_from_op(
            op, selectionCriteria, "permutes", ["rank", "dtype"]
        )
        self.paramArgs = _get_selection_info_from_op(
            op, selectionCriteria, "full_params", []
        )
        self.specificArgs = _get_selection_info_from_op(
            op, selectionCriteria, "specifics", {}
        )
        self.groupArgs = _get_selection_info_from_op(
            op, selectionCriteria, "groups", ["s"]
        )
        self.maximumPerPermute = _get_selection_info_from_op(
            op, selectionCriteria, "maximum", None
        )
        self.numErrorIfs = _get_selection_info_from_op(
            op, selectionCriteria, "num_errorifs", 1
        )
        self.selectAll = _get_selection_info_from_op(
            op, selectionCriteria, "all", False
        )

        # "maximum" defaults to None (no cap). Check for None before comparing:
        # `None > 1` raises TypeError, which previously crashed construction
        # whenever full_params was configured without a maximum.
        if (
            self.paramArgs
            and self.maximumPerPermute is not None
            and self.maximumPerPermute > 1
        ):
            logger.warning(f"Unsupported - selection params AND maximum for {opName}")

        # Every unique test added, plus test strings used to drop duplicates
        self.tests = []
        self.testStrings = set()
        self.shapes = set()

        # Per-permute selection state, keyed by each test's permute key
        self.permutes = set()
        self.testsPerPermute = {}
        self.paramsPerPermute = {}
        self.specificsPerPermute = {}

        self.selectionDone = False

    def __len__(self):
        return len(self.tests)

    def add(self, test):
        """Add a test and set up its permute key and group key.

        A test whose string was already added is skipped (logged at debug).
        """
        assert test.opName.startswith(self.opName)
        if str(test) in self.testStrings:
            logger.debug(f"Skipping duplicate test: {str(test)}")
            return

        self.tests.append(test)
        self.testStrings.add(str(test))

        self.shapes.add(test.getArg("shape"))

        # Work out the permutation key for this test
        permute = test.setKey(self.permuteArgs)
        # Set up the group key for the test (for pulling out groups during selection)
        test.setGroupKey(self.groupArgs)

        if permute not in self.permutes:
            # New permutation
            self.permutes.add(permute)
            # Set up area to record the selected tests
            self.testsPerPermute[permute] = []
            if self.paramArgs:
                # Set up area to record the unique test params found
                self.paramsPerPermute[permute] = {
                    param: set() for param in self.paramArgs
                }
            # Private copy of the specific args for this permute; matched
            # values are removed from it during selection
            self.specificsPerPermute[permute] = copy.deepcopy(self.specificArgs)

    def _init_select(self):
        """Choose the tests kept for each permute into testsPerPermute.

        Can only run once: it consumes the per-permute specifics and params
        state set up by add().
        """
        assert not self.selectionDone

        # Count of non-specific tests added to each permute (not error);
        # only consulted when not selecting all tests
        countPerPermute = {permute: 0 for permute in self.permutes}

        # Go through each test looking for permutes, unique params & specifics
        for test in self.tests:
            permute = test.getKey()
            append = False
            possible_append = False

            if test.isError():
                # Error test, choose up to number of tests
                if len(self.testsPerPermute[permute]) < self.numErrorIfs:
                    append = True
            elif self.selectAll:
                append = True
            else:
                # See if this is a specific test to add
                for param, values in self.specificsPerPermute[permute].items():
                    arg = test.getArg(param)
                    if arg in values:
                        # Found a match; remove it so only the first matching
                        # test in this permute is chosen for this value
                        values.remove(arg)
                        # Mark the test as special (and so shouldn't be removed)
                        test.setMark()
                        append = True

                if self.paramArgs:
                    # See if this test contains any new params we should keep.
                    # Checked even if the test is already selected so the
                    # params it covers are recorded.
                    seenParams = self.paramsPerPermute[permute]
                    for param in self.paramArgs:
                        arg = test.getArg(param)
                        if arg not in seenParams[param]:
                            # We have found a new value for this arg, record it
                            seenParams[param].add(arg)
                            possible_append = True
                else:
                    # No params set, so possible test to add up to maximum
                    possible_append = True

                if (not append and possible_append) and (
                    self.maximumPerPermute is None
                    or countPerPermute[permute] < self.maximumPerPermute
                ):
                    # Not selected but could be added and there is space left
                    # (or no maximum is set)
                    append = True
                    countPerPermute[permute] += 1

            # Check for grouping with chosen tests
            if not append:
                # Keep tests that form a group together: select this test if
                # an already selected test in this permute has the same group
                # key and is marked
                key = test.getGroupKey()
                for chosen in self.testsPerPermute[permute]:
                    if chosen.getGroupKey() == key and chosen.getMark():
                        test.setMark()
                        append = True
                        break

            if append:
                self.testsPerPermute[permute].append(test)

        self.selectionDone = True

    def select(self, rng=None):
        """Return the selected tests, running the selection on first use.

        On the first call, if ``rng`` is given, ``self.tests`` is shuffled in
        place with ``rng.shuffle`` before selecting. Later calls reuse the
        existing selection and ignore ``rng``.
        """
        if not self.selectionDone:
            if rng:
                rng.shuffle(self.tests)

            self._init_select()

        # Now create the full list of selected tests per permute
        selection = []
        for tests in self.testsPerPermute.values():
            selection.extend(tests)

        return selection

    def all(self):
        # Un-selected list of tests - i.e. all of them (in shuffled order if
        # select() was first called with an rng)
        return self.tests
| 305 | |
| 306 | |
class TestList:
    """List of all tests grouped by operator.

    Tests are bucketed into one TestOpList per ``test.opName``, each created
    with the shared selection config and criteria.
    """

    def __init__(self, opSelectionInfo, selectionCriteria="default"):
        self.opLists = {}
        self.opSelectionInfo = opSelectionInfo
        self.selectionCriteria = selectionCriteria

    def __len__(self):
        # Total number of tests held across every operator
        return sum(len(opList) for opList in self.opLists.values())

    def add(self, test):
        # Create the operator's bucket on first sight, then add the test to it
        name = test.opName
        if name not in self.opLists:
            self.opLists[name] = TestOpList(
                name,
                self.opSelectionInfo,
                self.selectionCriteria,
                test.testOpName,
            )
        self.opLists[name].add(test)

    def _get_tests(self, selectMode, rng):
        """Gather tests from every operator.

        In select mode, the per-op selections are concatenated and returned
        sorted; otherwise every test is returned in per-op order.
        """
        if selectMode:
            chosen = [
                test
                for opList in self.opLists.values()
                for test in opList.select(rng=rng)
            ]
            return sorted(chosen)
        return [test for opList in self.opLists.values() for test in opList.all()]

    def select(self, rng=None):
        return self._get_tests(True, rng)

    def all(self):
        return self._get_tests(False, None)