blob: e3f1738816afe889f02c6b23b7f9537344932d44 [file] [log] [blame]
Won Jeon74342e52024-01-09 00:34:40 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01002# SPDX-License-Identifier: Apache-2.0
3"""Select generated tests."""
4import argparse
5import itertools
6import json
7import logging
James Ward736fd1a2023-01-23 17:13:37 +00008import re
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01009from pathlib import Path
10from typing import Any
11from typing import Dict
12from typing import List
13
# Logger shared by every selector in this module; main() raises its level
# according to the number of -v flags given on the command line.
logger = logging.getLogger(name="test_select")
# Attach a default stderr handler to the root logger.
logging.basicConfig()
16
17
def expand_params(permutes: Dict[str, List[Any]], others: Dict[str, List[Any]]):
    """Generate permuted combinations of a dictionary of values and combine with others.

    permutes: a dictionary with sequences of values to be fully permuted
    others: a dictionary with sequences of values not fully permuted, but all used

    This yields dictionaries with one value from each of the items in permutes,
    combined with one value from each of the items in others.

    The number of dictionaries generated is the larger of the number of
    permutations of the permutes values and the length of the longest
    sequence in others. Shorter sequences are cycled, so every value is used.
    If both permutes and others are empty, a single empty dictionary is
    generated.

    Example 1:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7], "d": [True, False]}

        generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 5, "d": False},
        ]

    Example 2:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7, 8, 9], "d": [True, False]}

        generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 8, "d": False},
            {"a": 1, "b": 3, "c": 9, "d": True},
        ]

    Raises:
        ValueError if any item is in both permutes and others, or if an item
        has an empty sequence of values while other items still need values
        to be generated (there is nothing to pair them with)
    """
    for k in permutes:
        if k in others:
            raise ValueError(f"item conflict: {k}")

    p_keys = list(permutes)
    p_value_lists = list(permutes.values())
    # The product of the number of values in each permuted item. If permutes
    # is empty this stays at 1, so a single (empty) combination is produced
    # and a single, empty dictionary is returned when others is also empty.
    p_product_len = 1
    for values in p_value_lists:
        p_product_len *= len(values)

    o_keys = list(others)
    o_value_lists = list(others.values())

    # The number of params dictionaries generated is the maximum size of
    # the permuted values and the non-permuted values from others
    max_items = max([p_product_len] + [len(values) for values in o_value_lists])
    if max_items:
        # Cycling an empty sequence would stop immediately and surface as an
        # opaque RuntimeError from this generator - report it clearly instead
        empty = [
            k
            for k, values in itertools.chain(permutes.items(), others.items())
            if not values
        ]
        if empty:
            raise ValueError(f"no values for items: {empty}")

    # A cyclic generator over the product of all the permuted values; each
    # item is a tuple of values in the same order as p_keys
    p_generator = itertools.cycle(itertools.product(*p_value_lists))
    # A separate cyclic generator for each of the others value sequences
    o_generators = [itertools.cycle(values) for values in o_value_lists]

    for _ in range(max_items):
        params = dict(zip(p_keys, next(p_generator)))
        for key, generator in zip(o_keys, o_generators):
            params[key] = next(generator)
        yield params
106
107
class Operator:
    """Base class for operator specific selection properties."""

    # A registry of all Operator subclasses, indexed by the operator name
    registry = {}

    def __init_subclass__(cls, **kwargs):
        """Subclass initialiser to register all Operator classes."""
        super().__init_subclass__(**kwargs)
        cls.registry[cls.name] = cls

    # Derived classes must override the operator name
    name = None
    # Operators with additional parameters must override the param_names
    # NB: the order must match the order the values appear in the test names
    param_names = ["shape", "type"]

    # Working set of param_names - updated for negative tests
    wks_param_names = None

    def __init__(
        self,
        test_dir: Path,
        config: Dict[str, Dict[str, List[Any]]],
        negative=False,
        ignore_missing=False,
    ):
        """Initialise the selection parameters for an operator.

        test_dir: the directory where the tests for all operators can
            be found
        config: a dictionary with:
            "params" - a dictionary with mappings of parameter
                names to the values to select (a sub-set of
                expected values for instance)
            "permutes" - a list of parameter names to be permuted
            "preselected" - a list of dictionaries containing
                parameter names and pre-chosen values
            "sparsity" - a dictionary of parameter names with a
                sparsity value
            "full_sparsity" - "true"/"false" to use the sparsity
                value on permutes/params/preselected
            "exclude_patterns" - a list of regex's whereby each
                match will not be considered for selection.
                Exclusion happens BEFORE test selection (i.e.
                before permutes are applied).
            "errorifs" - list of ERRORIF case names to be selected
                after exclusion (negative tests)
        negative: bool indicating if negative testing is being selected
            which filters for ERRORIF in the test name and only selects
            the first test found (ERRORIF tests)
        ignore_missing: bool indicating if missing tests should be ignored

        The config dictionary is not modified: the negative test overrides
        and the default parameter values are applied to copies.

        EXAMPLE CONFIG (with non-json comments):
            "params": {
                "output_type": [
                    "outi8",
                    "outb"
                ]
            },
            "permutes": [
                "shape",
                "type"
            ],
            "sparsity": {
                "pad": 15
            },
            "preselected": [
                {
                    "shape": "6",
                    "type": "i8",
                    "pad": "pad00"
                }
            ],
            "exclude_patterns": [
                # Exclude positive (not ERRORIF) integer tests
                "^((?!ERRORIF).)*_(i8|i16|i32|b)_out(i8|i16|i32|b)",
                # Exclude negative (ERRORIF) i8 test
                ".*_ERRORIF_.*_i8_outi8"
            ],
            "errorifs": [
                "InputZeroPointNotZero"
            ]
        """
        assert isinstance(
            self.name, str
        ), f"{self.__class__.__name__}: {self.name} is not a valid operator name"

        # Shallow copy so the overrides below do not leak back into the
        # caller's (possibly shared) configuration
        config = dict(config)

        self.negative = negative
        self.ignore_missing = ignore_missing
        self.wks_param_names = self.param_names.copy()
        if self.negative:
            # need to override positive set up - use "errorifs" config if set
            # add in errorif case before shape to support all ops, including
            # different ops like COND_IF and CONVnD etc
            index = self.wks_param_names.index("shape")
            self.wks_param_names[index:index] = ["ERRORIF", "case"]
            config["params"] = {x: [] for x in self.wks_param_names}
            config["params"]["case"] = config.get("errorifs", [])
            config["permutes"] = []
            config["preselected"] = []

        # Copy params as default values are filled in below
        self.params = dict(config.get("params", {}))
        self.permutes = config.get("permutes", [])
        self.sparsity = config.get("sparsity", {})
        self.full_sparsity = config.get("full_sparsity") == "true"
        # A list of dictionaries of parameter names to pre-chosen values
        self.preselected = config.get("preselected", [])
        self.exclude_patterns = config.get("exclude_patterns", [])
        self.non_permutes = [x for x in self.wks_param_names if x not in self.permutes]
        logger.info(f"{self.name}: permutes={self.permutes}")
        logger.info(f"{self.name}: non_permutes={self.non_permutes}")
        logger.info(f"{self.name}: exclude_patterns={self.exclude_patterns}")

        # Drop any test whose whole name matches one of the exclude patterns
        self.test_paths = []
        excluded_paths = []
        for path in self.get_test_paths(test_dir, self.negative):
            if any(re.fullmatch(pattern, path.name) for pattern in self.exclude_patterns):
                excluded_paths.append(path)
            else:
                self.test_paths.append(path)

        logger.debug(f"{self.name}: regex excluded paths={excluded_paths}")

        if not self.test_paths:
            logger.error(f"no tests found for {self.name} in {test_dir}")
        logger.debug(f"{self.name}: paths={self.test_paths}")

        # get default parameter values for any not given (or empty) in the config
        default_params = self.get_default_params()
        for param, values in default_params.items():
            if not self.params.get(param):
                self.params[param] = values
        for param in self.wks_param_names:
            logger.info(f"{self.name}: params[{param}]={self.params[param]}")

    @staticmethod
    def _get_test_paths(test_dir: Path, base_dir_glob, path_glob, negative):
        """Generate test paths for operators using operator specifics.

        test_dir: directory searched for base directories matching base_dir_glob
        base_dir_glob: glob for the operator's base directories in test_dir
        path_glob: glob for the tests within each base directory
        negative: True to only generate ERRORIF tests, False to exclude them

        Test set directories ("<name>_s<N>") and extra test directories
        ("<name>_full", "<name>_fs") are reported once per set/extra as the
        truncated "<name>" path, which does not exist on disk. Any other
        "_s<N>" set is skipped as it is covered by the truncated name.
        NOTE(review): a name with both "_s0" and "_full"/"_fs" directories is
        generated more than once; kept as-is because sparsity indices depend
        on the order and number of paths.
        """
        for base_dir in sorted(test_dir.glob(base_dir_glob)):
            for path in sorted(base_dir.glob(path_glob)):
                if (not negative and "ERRORIF" not in str(path)) or (
                    negative and "ERRORIF" in str(path)
                ):
                    # Check for test set paths - the suffix must end the name,
                    # as the sets are rebuilt by appending it to the truncated
                    # name (see _get_test_set_paths/_get_extra_test_paths)
                    match = re.fullmatch(r"(.*)_(s[0-9]+|full|fs)", path.name)
                    if match:
                        if match.group(2) in ["s0", "full", "fs"]:
                            # Only return the truncated test name
                            # of the first test of a set, and for full tests
                            yield path.with_name(match.group(1))
                    else:
                        yield path

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths for this operator."""
        yield from Operator._get_test_paths(test_dir, f"{cls.name}*", "*", negative)

    def path_params(self, path):
        """Return a dictionary of params from the test path.

        The test name is split on "_"; after the operator name parts, each
        remaining part is the value for the matching name in wks_param_names.
        """
        op_name_parts = self.name.split("_")
        values = path.name.split("_")[len(op_name_parts) :]
        assert len(values) == len(
            self.wks_param_names
        ), f"len({values}) == len({self.wks_param_names})"
        return dict(zip(self.wks_param_names, values))

    def get_default_params(self):
        """Get the default parameter values from the test names.

        Returns a dictionary mapping each working parameter name to the
        sorted list of distinct values found in the test paths.
        """
        params = {param: set() for param in self.wks_param_names}
        for path in self.test_paths:
            path_params = self.path_params(path)
            for k, values in params.items():
                values.add(path_params[k])
        return {param: sorted(values) for param, values in params.items()}

    @staticmethod
    def _get_test_set_paths(path):
        """Expand a path to find all the test sets ("<name>_s0", "<name>_s1", ...)."""
        paths = []
        # Have a bound for the maximum test sets
        for s in range(100):
            set_path = path.with_name(f"{path.name}_s{s}")
            if not set_path.exists():
                if s == 0:
                    logger.warning(f"Could not find test set 0 - {str(set_path)}")
                break
            paths.append(set_path)
        return paths

    @staticmethod
    def _get_extra_test_paths(path):
        """Expand a path to find extra tests ("<name>_full", "<name>_fs")."""
        paths = []
        for suffix in ["full", "fs"]:
            suffix_path = path.with_name(f"{path.name}_{suffix}")
            if suffix_path.exists():
                paths.append(suffix_path)
        return paths

    @staticmethod
    def _expand_selected_path(path):
        """Generate the real test paths for a selected test path.

        A path that exists is yielded as-is. Otherwise it is a truncated
        test set name, so all its test sets and extra tests are yielded.
        """
        if path.exists():
            yield path
        else:
            yield from Operator._get_test_set_paths(path)
            yield from Operator._get_extra_test_paths(path)

    def select_tests(self):  # noqa: C901 (function too complex)
        """Generate the paths to the selected tests for this operator.

        First, select tests matching a preselected or permuted parameter
        combination (in negative mode only one test per ERRORIF case); when
        full_sparsity is set the sparsity values also apply here. Then select
        further tests that use any parameter value not yet covered, skipping
        some according to the sparsity values. Unless ignore_missing is set,
        anything wanted but not found is reported at the end.
        """
        if not self.test_paths:
            # Exit early when nothing to select from
            return

        # the test paths that have not been selected yet
        unused_paths = set(self.test_paths)

        # a list of dictionaries of unused preselected parameter combinations
        unused_preselected = list(self.preselected)
        logger.debug(f"preselected: {unused_preselected}")

        # a list of dictionaries of unused permuted parameter combinations
        permutes = {k: self.params[k] for k in self.permutes}
        others = {k: self.params[k] for k in self.non_permutes}
        unused_permuted = list(expand_params(permutes, others))
        logger.debug(f"permuted: {unused_permuted}")

        # a dictionary of sets of unused parameter values
        if self.negative:
            # We only care about selecting a test for each errorif case
            unused_values = {k: set() for k in self.params}
            unused_values["case"] = set(self.params["case"])
        else:
            unused_values = {k: set(v) for k, v in self.params.items()}

        # select tests matching permuted, or preselected, parameter combinations
        for n, path in enumerate(self.test_paths):
            path_params = self.path_params(path)
            in_preselected = path_params in unused_preselected
            in_permuted = path_params in unused_permuted
            if not (in_permuted or in_preselected):
                continue
            unused_paths.remove(path)
            if in_preselected:
                unused_preselected.remove(path_params)
            if in_permuted:
                unused_permuted.remove(path_params)
            if self.negative:
                # remove any other errorif cases, so we only match one
                for p in list(unused_permuted):
                    if p["case"] == path_params["case"]:
                        unused_permuted.remove(p)
            if self.full_sparsity:
                # Test for sparsity
                skip_key = next(
                    (
                        k
                        for k in path_params
                        if k in self.sparsity and n % self.sparsity[k] != 0
                    ),
                    None,
                )
                if skip_key is not None:
                    logger.debug(f"Skipping due to {skip_key} sparsity - {path.name}")
                    continue
            # remove the param values used by this path
            for k, value in path_params.items():
                unused_values[k].discard(value)
            logger.debug(f"FOUND wanted: {path.name}")
            yield from self._expand_selected_path(path)

        # search for tests that match any unused parameter values
        for n, path in enumerate(sorted(unused_paths)):
            path_params = self.path_params(path)
            # select paths with unused param values
            # skipping some, if sparsity is set for the param
            for k in path_params:
                if path_params[k] in unused_values[k] and (
                    k not in self.sparsity or n % self.sparsity[k] == 0
                ):
                    # remove the param values used by this path
                    for p, value in path_params.items():
                        unused_values[p].discard(value)
                    sparsity = self.sparsity.get(k, 0)
                    logger.debug(f"FOUND unused [{k}/{n}/{sparsity}]: {path.name}")
                    # Expand test sets and extra tests, as for wanted tests
                    yield from self._expand_selected_path(path)
                    break

        if not self.ignore_missing:
            # report any preselected combinations that were not found
            for params in unused_preselected:
                logger.warning(f"MISSING preselected: {params}")
            # report any permuted combinations that were not found
            for params in unused_permuted:
                logger.debug(f"MISSING permutation: {params}")
            # report any param values that were not found
            for k, values in unused_values.items():
                if not values:
                    continue
                if k not in self.sparsity:
                    logger.warning(f"MISSING {len(values)} values for {k}: {values}")
                else:
                    logger.info(
                        f"Skipped {len(values)} values for {k} due to sparsity setting"
                    )
                    logger.debug(f"Values skipped: {values}")
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100432
433
class AbsOperator(Operator):
    """Select tests for the ABS operator."""

    name = "abs"


class ArithmeticRightShiftOperator(Operator):
    """Select tests for the ARITHMETIC_RIGHT_SHIFT operator."""

    name = "arithmetic_right_shift"
    param_names = [
        "shape",
        "type",
        "rounding",
    ]


class AddOperator(Operator):
    """Select tests for the ADD operator."""

    name = "add"


class AddShapeOperator(Operator):
    """Select tests for the ADD_SHAPE operator."""

    name = "add_shape"


class ArgmaxOperator(Operator):
    """Select tests for the ARGMAX operator."""

    name = "argmax"
    param_names = [
        "shape",
        "type",
        "axis",
    ]


class AvgPool2dOperator(Operator):
    """Select tests for the AVG_POOL2D operator."""

    name = "avg_pool2d"
    param_names = [
        "shape",
        "type",
        "accum_type",
        "stride",
        "kernel",
        "pad",
    ]


class BitwiseAndOperator(Operator):
    """Select tests for the BITWISE_AND operator."""

    name = "bitwise_and"


class BitwiseNotOperator(Operator):
    """Select tests for the BITWISE_NOT operator."""

    name = "bitwise_not"


class BitwiseOrOperator(Operator):
    """Select tests for the BITWISE_OR operator."""

    name = "bitwise_or"


class BitwiseXorOperator(Operator):
    """Select tests for the BITWISE_XOR operator."""

    name = "bitwise_xor"


class CastOperator(Operator):
    """Select tests for the CAST operator."""

    name = "cast"
    param_names = [
        "shape",
        "type",
        "output_type",
    ]


class CeilOperator(Operator):
    """Select tests for the CEIL operator."""

    name = "ceil"


class ClampOperator(Operator):
    """Select tests for the CLAMP operator."""

    name = "clamp"


class CLZOperator(Operator):
    """Select tests for the CLZ operator."""

    name = "clz"


class ConcatOperator(Operator):
    """Select tests for the CONCAT operator."""

    name = "concat"
    param_names = [
        "shape",
        "type",
        "axis",
    ]


class ConcatShapeOperator(Operator):
    """Select tests for the CONCAT_SHAPE operator."""

    name = "concat_shape"
533
534
class CondIfOperator(Operator):
    """Test selector for the COND_IF operator."""

    name = "cond_if"
    param_names = ["variant", "shape", "type", "cond"]


class ConstOperator(Operator):
    """Test selector for the CONST operator."""

    name = "const"


class ConstShapeOperator(Operator):
    """Test selector for the CONST_SHAPE operator."""

    name = "const_shape"


class Conv2dOperator(Operator):
    """Test selector for the CONV2D operator."""

    name = "conv2d"
    param_names = ["kernel", "shape", "type", "accum_type", "stride", "pad", "dilation"]


class Conv3dOperator(Operator):
    """Test selector for the CONV3D operator."""

    name = "conv3d"
    param_names = ["kernel", "shape", "type", "accum_type", "stride", "pad", "dilation"]


class DepthwiseConv2dOperator(Operator):
    """Test selector for the DEPTHWISE_CONV2D operator."""

    name = "depthwise_conv2d"
    param_names = ["kernel", "shape", "type", "accum_type", "stride", "pad", "dilation"]


class DimOperator(Operator):
    """Test selector for the DIM operator."""

    name = "dim"
    param_names = ["shape", "type", "axis"]


# Backwards compatible alias for the original, misspelt, class name.
# A plain assignment does not trigger __init_subclass__, so the registry
# still holds a single "dim" entry.
DimOeprator = DimOperator


class DivShapeOperator(Operator):
    """Test selector for the DIV_SHAPE operator."""

    name = "div_shape"


class EqualOperator(Operator):
    """Test selector for the EQUAL operator."""

    name = "equal"


class ExpOperator(Operator):
    """Test selector for the EXP operator."""

    name = "exp"


class ErfOperator(Operator):
    """Test selector for the ERF operator."""

    name = "erf"
604
605
class FFT2DOperator(Operator):
    """Select tests for the FFT2D operator."""

    name = "fft2d"
    param_names = [
        "shape",
        "type",
        "inverse",
    ]


class FloorOperator(Operator):
    """Select tests for the FLOOR operator."""

    name = "floor"


class FullyConnectedOperator(Operator):
    """Select tests for the FULLY_CONNECTED operator."""

    name = "fully_connected"
    param_names = [
        "shape",
        "type",
        "accum_type",
    ]


class GatherOperator(Operator):
    """Select tests for the GATHER operator."""

    name = "gather"


class GreaterOperator(Operator):
    """Select tests for the GREATER operator."""

    name = "greater"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate this operator's test paths from the exactly named base directory.

        The base "<name>*" glob is not used, presumably so that the
        greater_equal directories are not picked up as well.
        """
        yield from Operator._get_test_paths(test_dir, cls.name, "*", negative)


class GreaterEqualOperator(Operator):
    """Select tests for the GREATER_EQUAL operator."""

    name = "greater_equal"


class IdentityOperator(Operator):
    """Select tests for the IDENTITY operator."""

    name = "identity"


class IntDivOperator(Operator):
    """Select tests for the INTDIV operator."""

    name = "intdiv"


class LogOperator(Operator):
    """Select tests for the LOG operator."""

    name = "log"


class LogicalAndOperator(Operator):
    """Select tests for the LOGICAL_AND operator."""

    name = "logical_and"


class LogicalLeftShiftOperator(Operator):
    """Select tests for the LOGICAL_LEFT_SHIFT operator."""

    name = "logical_left_shift"


class LogicalNotOperator(Operator):
    """Select tests for the LOGICAL_NOT operator."""

    name = "logical_not"


class LogicalOrOperator(Operator):
    """Select tests for the LOGICAL_OR operator."""

    name = "logical_or"


class LogicalRightShiftOperator(Operator):
    """Select tests for the LOGICAL_RIGHT_SHIFT operator."""

    name = "logical_right_shift"


class LogicalXorOperator(Operator):
    """Select tests for the LOGICAL_XOR operator."""

    name = "logical_xor"


class MatmulOperator(Operator):
    """Select tests for the MATMUL operator."""

    name = "matmul"
    param_names = [
        "shape",
        "type",
        "accum_type",
    ]


class MaximumOperator(Operator):
    """Select tests for the MAXIMUM operator."""

    name = "maximum"


class MaxPool2dOperator(Operator):
    """Select tests for the MAX_POOL2D operator."""

    name = "max_pool2d"
    param_names = [
        "shape",
        "type",
        "stride",
        "kernel",
        "pad",
    ]


class MinimumOperator(Operator):
    """Select tests for the MINIMUM operator."""

    name = "minimum"


class MulOperator(Operator):
    """Select tests for the MUL operator."""

    name = "mul"
    param_names = [
        "shape",
        "type",
        "perm",
        "shift",
    ]


class MulShapeOperator(Operator):
    """Select tests for the MUL_SHAPE operator."""

    name = "mul_shape"


class NegateOperator(Operator):
    """Select tests for the NEGATE operator."""

    name = "negate"


class PadOperator(Operator):
    """Select tests for the PAD operator."""

    name = "pad"
    param_names = [
        "shape",
        "type",
        "pad",
    ]


class PowOperator(Operator):
    """Select tests for the POW operator."""

    name = "pow"
759
760
class ReciprocalOperator(Operator):
    """Select tests for the RECIPROCAL operator."""

    name = "reciprocal"


class ReduceAllOperator(Operator):
    """Select tests for the REDUCE_ALL operator."""

    name = "reduce_all"
    param_names = [
        "shape",
        "type",
        "axis",
    ]


class ReduceAnyOperator(Operator):
    """Select tests for the REDUCE_ANY operator."""

    name = "reduce_any"
    param_names = [
        "shape",
        "type",
        "axis",
    ]


class ReduceMaxOperator(Operator):
    """Select tests for the REDUCE_MAX operator."""

    name = "reduce_max"
    param_names = [
        "shape",
        "type",
        "axis",
    ]


class ReduceMinOperator(Operator):
    """Select tests for the REDUCE_MIN operator."""

    name = "reduce_min"
    param_names = [
        "shape",
        "type",
        "axis",
    ]


class ReduceProductOperator(Operator):
    """Select tests for the REDUCE_PRODUCT operator."""

    name = "reduce_product"
    param_names = [
        "shape",
        "type",
        "axis",
    ]


class ReduceSumOperator(Operator):
    """Select tests for the REDUCE_SUM operator."""

    name = "reduce_sum"
    param_names = [
        "shape",
        "type",
        "axis",
    ]


class RescaleOperator(Operator):
    """Select tests for the RESCALE operator."""

    name = "rescale"
    param_names = ["shape", "type", "output_type", "scale", "double_round", "per_channel"]


class ReshapeOperator(Operator):
    """Select tests for the RESHAPE operator."""

    name = "reshape"
    param_names = [
        "shape",
        "type",
        "perm",
        "rank",
        "out",
    ]


class ResizeOperator(Operator):
    """Select tests for the RESIZE operator."""

    name = "resize"
    param_names = ["shape", "type", "mode", "output_type", "scale", "offset", "border"]


class ReverseOperator(Operator):
    """Select tests for the REVERSE operator."""

    name = "reverse"
    param_names = [
        "shape",
        "type",
        "axis",
    ]


class RFFT2DOperator(Operator):
    """Select tests for the RFFT2D operator."""

    name = "rfft2d"


class RsqrtOperator(Operator):
    """Select tests for the RSQRT operator."""

    name = "rsqrt"


class CosOperator(Operator):
    """Select tests for the COS operator."""

    name = "cos"


class SinOperator(Operator):
    """Select tests for the SIN operator."""

    name = "sin"
874
875
class ScatterOperator(Operator):
    """Select tests for the SCATTER operator."""

    name = "scatter"


class SelectOperator(Operator):
    """Select tests for the SELECT operator."""

    name = "select"


class SigmoidOperator(Operator):
    """Select tests for the SIGMOID operator."""

    name = "sigmoid"


class SliceOperator(Operator):
    """Select tests for the SLICE operator."""

    name = "slice"
    param_names = [
        "shape",
        "type",
        "perm",
    ]


class SubOperator(Operator):
    """Select tests for the SUB operator."""

    name = "sub"


class SubShapeOperator(Operator):
    """Select tests for the SUB_SHAPE operator."""

    name = "sub_shape"


class TableOperator(Operator):
    """Select tests for the TABLE operator."""

    name = "table"


class TanhOperator(Operator):
    """Select tests for the TANH operator."""

    name = "tanh"


class TileOperator(Operator):
    """Select tests for the TILE operator."""

    name = "tile"
    param_names = [
        "shape",
        "type",
        "perm",
    ]


class TransposeOperator(Operator):
    """Select tests for the TRANSPOSE operator."""

    name = "transpose"
    param_names = [
        "shape",
        "type",
        "perm",
    ]

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate this operator's test paths from the exactly named base directory.

        The base "<name>*" glob is not used, presumably so that the
        transpose_conv2d directories are not picked up as well.
        """
        yield from Operator._get_test_paths(test_dir, cls.name, "*", negative)


class TransposeConv2dOperator(Operator):
    """Select tests for the TRANSPOSE_CONV2D operator."""

    name = "transpose_conv2d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "out_shape",
    ]

    def path_params(self, path):
        """Return the test path parameters with out_shape blanked out.

        Every test case has a different out_shape, so it is ignored for
        selection by giving it the same (empty) value everywhere.
        """
        return {**super().path_params(path), "out_shape": ""}


class WhileLoopOperator(Operator):
    """Select tests for the WHILE_LOOP operator."""

    name = "while_loop"
    param_names = [
        "shape",
        "type",
        "cond",
    ]
971
972
def parse_args(argv=None):
    """Parse the command line arguments.

    argv: optional list of argument strings to parse; when None, argparse
        uses the program's command line (sys.argv[1:])

    Returns the parsed argparse.Namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-dir",
        default=Path.cwd(),
        type=Path,
        help=(
            "The directory where test subdirectories for all operators can be found"
            " (default: current working directory)"
        ),
    )
    parser.add_argument(
        "--config",
        default=Path(__file__).with_suffix(".json"),
        type=Path,
        help="A JSON file defining the parameters to use for each operator",
    )
    parser.add_argument(
        "--selector",
        default="default",
        type=str,
        help="The selector in the selection dictionary to use for each operator",
    )
    parser.add_argument(
        "--full-path", action="store_true", help="output the full path for each test"
    )
    parser.add_argument(
        "-v",
        dest="verbosity",
        action="count",
        default=0,
        help="Verbosity (can be used multiple times for more details)",
    )
    parser.add_argument(
        "operators",
        type=str,
        nargs="*",
        help=(
            "Select tests for the specified operator(s)"
            " - all operators are assumed if none are specified"
            f" - choose from: {list(Operator.registry)}"
        ),
    )
    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative"],
        default="positive",
        type=str,
        help="type of tests selected, positive or negative",
    )
    return parser.parse_args(argv)
1026
1027
def main():
    """Select tests for the requested operators and print their paths.

    Reads the JSON config file, then for each registered operator (or only
    those named on the command line) prints the selected test names, or
    their resolved full paths with --full-path.

    Returns 0 on success, or 2 if the config file cannot be read, parsed,
    or does not contain a JSON object.
    """
    args = parse_args()

    loglevels = (logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG)
    logger.setLevel(loglevels[min(args.verbosity, len(loglevels) - 1)])
    logger.info(f"{__file__}: args: {args}")

    try:
        with open(args.config, "r", encoding="utf-8") as fd:
            config = json.load(fd)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError also covers
        # json.JSONDecodeError and UnicodeDecodeError
        logger.error(f"Config file error: {e}")
        return 2
    if not isinstance(config, dict):
        logger.error("Config file error: expected a JSON object of operators")
        return 2

    negative = args.test_type == "negative"
    for op_name, op_class in Operator.registry.items():
        if args.operators and op_name not in args.operators:
            continue
        op_params = config.get(op_name, {})
        selections = op_params.get("selection", {})
        if args.selector in selections:
            selection_config = selections[args.selector]
        else:
            logger.warning(
                f"Could not find selection config {args.selector} for {op_name}"
            )
            selection_config = {}
        op = op_class(args.test_dir, selection_config, negative)
        for test_path in op.select_tests():
            print(test_path.resolve() if args.full_path else test_path.name)

    return 0
1059
1060
if __name__ == "__main__":
    # Use SystemExit rather than the site-injected exit() helper, which is
    # intended for the interactive interpreter and may be unavailable
    # (e.g. when Python is run with -S)
    raise SystemExit(main())