blob: 5a13178c4dd76217585b794add6b22e43e53f25c [file] [log] [blame]
Jeremy Johnsonaf090182024-02-13 18:25:39 +00001# Copyright (c) 2024, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
3import copy
4import logging
5
# Import-time side effect: give the root logger a default handler/format so
# messages from the selection code below are printed.
logging.basicConfig()
# Named logger shared by everything in this module.
logger = logging.getLogger(name="tosa_verif_build_tests")
8
9
class Test:
    """Test container to allow group and permute selection.

    Holds one generated test (op name, test string, dtype, error, input shapes
    and argument dictionary) together with the selection state that the
    selection code fills in later: the permute key, the group key and the
    "important" mark. ``error`` is ``None`` for a non-error test.
    """

    def __init__(
        self, opName, testStr, dtype, error, shapeList, argsDict, testOpName=None
    ):
        self.opName = opName
        self.testStr = testStr
        self.dtype = dtype
        self.error = error
        self.shapeList = shapeList
        self.argsDict = argsDict
        # Op name used for look up in TOSA_OP_LIST (e.g. "conv2d_1x1");
        # defaults to opName when not given
        self.testOpName = opName if testOpName is None else testOpName

        # Selection state, populated by setKey / setGroupKey / setMark
        self.key = None
        self.groupKey = None
        self.mark = False

    def __str__(self):
        return self.testStr

    def __lt__(self, other):
        # Tests sort by their test string
        return self.testStr < str(other)

    def getArg(self, param):
        """Return a hashable value for ``param`` of this test.

        "rank" is the length of the first shape, "dtype" is the dtype (a list
        dtype is returned as a tuple), and "shape" is the string form of the
        first shape unless the args dictionary has its own "shape" entry.
        Any other parameter found in the args dictionary is returned as its
        string form with newlines joined by commas; unknown parameters give
        ``None``.
        """
        if param == "rank":
            return len(self.shapeList[0])
        if param == "dtype":
            dtype = self.dtype
            return tuple(dtype) if isinstance(dtype, list) else dtype
        if param == "shape" and "shape" not in self.argsDict:
            return str(self.shapeList[0])
        if param not in self.argsDict:
            return None
        # Turn the argument into a hashable single-line string
        return ",".join(str(self.argsDict[param]).splitlines())

    def setKey(self, keyParams):
        """Compute, store and return the permute key for this test.

        Non-error tests are keyed by a tuple of ``getArg`` values for
        ``keyParams``; error tests are keyed by their error alone.
        """
        if self.error is not None:
            self.key = self.error
        else:
            self.key = tuple(self.getArg(param) for param in keyParams)
        return self.key

    def getKey(self):
        return self.key

    def setGroupKey(self, groupParams):
        """Compute, store and return the group key for this test.

        The key is built from every parameter ("shape", "dtype" and all args,
        in sorted name order) EXCEPT those in ``groupParams``, so tests that
        differ only in their group parameters (like test set number) share it.
        """
        params = sorted(["shape", "dtype", *self.argsDict])
        self.groupKey = tuple(
            self.getArg(param) for param in params if param not in groupParams
        )
        return self.groupKey

    def getGroupKey(self):
        return self.groupKey

    def inGroup(self, groupKey):
        return self.groupKey == groupKey

    def setMark(self):
        # Marks the test as important (kept, and pulls in its group)
        self.mark = True

    def getMark(self):
        return self.mark

    def isError(self):
        return self.error is not None
94
95
96def _get_selection_info_from_op(op, selectionCriteria, item, default):
97 # Get selection info from the op
98 if (
99 "selection" in op
100 and selectionCriteria in op["selection"]
101 and item in op["selection"][selectionCriteria]
102 ):
103 return op["selection"][selectionCriteria][item]
104 else:
105 return default
106
107
108def _get_tests_by_group(tests):
109 # Create simple structures to record the tests in groups
110 groups = []
111 group_tests = {}
112
113 for test in tests:
114 key = test.getGroupKey()
115 if key in group_tests:
116 group_tests[key].append(test)
117 else:
118 group_tests[key] = [test]
119 groups.append(key)
120
121 # Return list of test groups (group keys) and a dictionary with a list of tests
122 # associated with each group key
123 return groups, group_tests
124
125
126def _get_specific_op_info(opName, opSelectionInfo, testOpName):
127 # Get the op specific section from the selection config
128 name = opName if opName in opSelectionInfo else testOpName
129 if name not in opSelectionInfo:
130 logger.info(f"No op entry found for {opName} in test selection config")
131 return {}
132 return opSelectionInfo[name]
133
134
class TestOpList:
    """All the tests for one op grouped by permutations.

    Tests are bucketed by their permute key (``Test.setKey`` with the op's
    "permutes" selection arguments); ``select`` then chooses a subset of the
    tests in every bucket using the op's selection config.
    """

    def __init__(self, opName, opSelectionInfo, selectionCriteria, testOpName):
        self.opName = opName
        self.testOpName = testOpName
        op = _get_specific_op_info(opName, opSelectionInfo, testOpName)

        # See verif/conformance/README.md for more information on
        # these selection arguments
        self.permuteArgs = _get_selection_info_from_op(
            op, selectionCriteria, "permutes", ["rank", "dtype"]
        )
        self.paramArgs = _get_selection_info_from_op(
            op, selectionCriteria, "full_params", []
        )
        self.specificArgs = _get_selection_info_from_op(
            op, selectionCriteria, "specifics", {}
        )
        self.groupArgs = _get_selection_info_from_op(
            op, selectionCriteria, "groups", ["s"]
        )
        # None means no maximum per permute
        self.maximumPerPermute = _get_selection_info_from_op(
            op, selectionCriteria, "maximum", None
        )
        self.numErrorIfs = _get_selection_info_from_op(
            op, selectionCriteria, "num_errorifs", 1
        )
        self.selectAll = _get_selection_info_from_op(
            op, selectionCriteria, "all", False
        )

        # Only compare the maximum when one is set: ordering None against an
        # int raises TypeError in Python 3.
        if (
            self.paramArgs
            and self.maximumPerPermute is not None
            and self.maximumPerPermute > 1
        ):
            logger.warning(f"Unsupported - selection params AND maximum for {opName}")

        self.tests = []
        self.testStrings = set()
        self.shapes = set()

        self.permutes = set()
        self.testsPerPermute = {}
        self.paramsPerPermute = {}
        self.specificsPerPermute = {}

        self.selectionDone = False

    def __len__(self):
        return len(self.tests)

    def add(self, test):
        """Add a test to this op and set up its permute/group keys.

        Tests whose string has already been added are skipped (and logged).
        """
        assert test.opName.startswith(self.opName)
        if str(test) in self.testStrings:
            logger.info(f"Skipping duplicate test: {str(test)}")
            return

        self.tests.append(test)
        self.testStrings.add(str(test))

        self.shapes.add(test.getArg("shape"))

        # Work out the permutation key for this test
        permute = test.setKey(self.permuteArgs)
        # Set up the group key for the test (for pulling out groups during selection)
        test.setGroupKey(self.groupArgs)

        if permute not in self.permutes:
            # New permutation
            self.permutes.add(permute)
            # Set up area to record the selected tests
            self.testsPerPermute[permute] = []
            if self.paramArgs:
                # Set up area to record the unique test params found
                self.paramsPerPermute[permute] = {
                    param: set() for param in self.paramArgs
                }
            # Per-permute copy of the specifics, as selection consumes them
            self.specificsPerPermute[permute] = copy.deepcopy(self.specificArgs)

    def _match_specifics(self, test, permute):
        """Consume any "specifics" values this test matches; True if any did."""
        matched = False
        for param, values in self.specificsPerPermute[permute].items():
            arg = test.getArg(param)
            if arg in values:
                # Consume the matched value so later tests don't match it again
                values.remove(arg)
                # Mark the test as special (and so shouldn't be removed)
                test.setMark()
                matched = True
        return matched

    def _record_new_params(self, test, permute):
        """Record this test's "full_params" values; True if any were new."""
        foundNew = False
        seen = self.paramsPerPermute[permute]
        # Check every param (no early exit) so all new values are recorded
        for param in self.paramArgs:
            arg = test.getArg(param)
            if arg not in seen[param]:
                seen[param].add(arg)
                foundNew = True
        return foundNew

    def _init_select(self):
        """Choose the tests for every permute (can only run once).

        - Error tests: the first ``numErrorIfs`` tests per error key.
        - ``selectAll``: every test.
        - Otherwise a test is chosen (and marked) if it matches a remaining
          "specifics" value; else it is a candidate if it has a new
          "full_params" value (or always, when no params are configured), and
          candidates are chosen up to ``maximumPerPermute`` per permute.
        - Any unchosen test sharing a group key with a marked chosen test in
          the same permute is also chosen and marked.
        """
        # Can only perform the selection process once as it alters the permute
        # information set at init
        assert not self.selectionDone

        # Count of non-specific tests added to each permute (not error)
        countPerPermute = {permute: 0 for permute in self.permutes}

        # Go through each test looking for permutes, unique params & specifics
        for test in self.tests:
            permute = test.getKey()
            selected = self.testsPerPermute[permute]

            if test.isError():
                # Error test, choose up to number of tests
                append = len(selected) < self.numErrorIfs
            elif self.selectAll:
                append = True
            else:
                append = self._match_specifics(test, permute)

                if self.paramArgs:
                    # Run even if already selected so the params are recorded
                    possible_append = self._record_new_params(test, permute)
                else:
                    # No params set, so possible test to add up to maximum
                    possible_append = True

                if (not append and possible_append) and (
                    self.maximumPerPermute is None
                    or countPerPermute[permute] < self.maximumPerPermute
                ):
                    # Not selected but could be added and we have space left if
                    # a maximum is set.
                    append = True
                    countPerPermute[permute] += 1

            # Keep any tests together that form a group with a marked test
            if not append:
                key = test.getGroupKey()
                if any(t.getGroupKey() == key and t.getMark() for t in selected):
                    test.setMark()
                    append = True

            if append:
                selected.append(test)

        self.selectionDone = True

    def select(self, rng=None):
        """Return the selected tests, running the selection on the first call.

        rng: optional object whose ``shuffle`` is called on the test list
        before the one-off selection (its return value is ignored, so it is
        expected to shuffle in place).
        """
        if not self.selectionDone:
            if rng is not None:
                rng.shuffle(self.tests)

            self._init_select()

        # Now create the full list of selected tests per permute
        selection = []
        for tests in self.testsPerPermute.values():
            selection.extend(tests)

        return selection

    def all(self):
        # Un-selected list of tests - i.e. all of them
        return self.tests
305
306
class TestList:
    """List of all tests grouped by operator.

    Each op name maps to a ``TestOpList`` created on first use with this
    list's selection info and criteria.
    """

    def __init__(self, opSelectionInfo, selectionCriteria="default"):
        self.opLists = {}
        self.opSelectionInfo = opSelectionInfo
        self.selectionCriteria = selectionCriteria

    def __len__(self):
        # Total number of tests across all ops
        return sum(len(opList) for opList in self.opLists.values())

    def add(self, test):
        """Add ``test`` to the per-op list for its op name, creating it if needed."""
        opList = self.opLists.get(test.opName)
        if opList is None:
            opList = TestOpList(
                test.opName,
                self.opSelectionInfo,
                self.selectionCriteria,
                test.testOpName,
            )
            self.opLists[test.opName] = opList
        opList.add(test)

    def _get_tests(self, selectMode, rng):
        # Gather either the selected or all tests from every op, then sort them
        gathered = []
        for opList in self.opLists.values():
            gathered.extend(opList.select(rng=rng) if selectMode else opList.all())
        return sorted(gathered)

    def select(self, rng=None):
        return self._get_tests(True, rng)

    def all(self):
        return self._get_tests(False, None)
348 return self._get_tests(False, None)