blob: 55eef586db47419daaa30c1d61a5bef84b1ee0bf [file] [log] [blame]
Won Jeon74342e52024-01-09 00:34:40 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01002# SPDX-License-Identifier: Apache-2.0
3"""Select generated tests."""
4import argparse
5import itertools
6import json
7import logging
James Ward736fd1a2023-01-23 17:13:37 +00008import re
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01009from pathlib import Path
10from typing import Any
11from typing import Dict
12from typing import List
13
# Give the root logger a default handler, then create the named logger
# that every selector in this module reports through.
logging.basicConfig()
logger = logging.getLogger(name="test_select")
16
17
def expand_params(permutes: Dict[str, List[Any]], others: Dict[str, List[Any]]):
    """Generate permuted combinations of a dictionary of values and combine with others.

    permutes: a dictionary with sequences of values to be fully permuted
    others: a dictionary with sequences of values not fully permuted, but all used

    This yields dictionaries with one value from each of the items in permutes,
    combined with one value from each of the items in others. The number of
    dictionaries yielded is the larger of the number of permutations and the
    length of the longest sequence in others; shorter sequences (and the
    permutations themselves) are cycled.

    If any item in permutes or others has no values, no complete dictionary
    can be formed and nothing is yielded.

    Example 1:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 5, "d": False},
        ]

    Example 2:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7, 8, 9], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 8, "d": False},
            {"a": 1, "b": 3, "c": 9, "d": True},
        ]

    Raises:
        ValueError if any item is in both permutes and others
    """
    for k in permutes:
        if k in others:
            raise ValueError(f"item conflict: {k}")

    p_keys = list(permutes)
    # if permutes is empty the product is a single empty combination, so we
    # start at 1 - this returns a single, empty dictionary if others is also
    # empty
    p_product_len = 1
    for values in permutes.values():
        p_product_len *= len(values)

    o_keys = list(others)
    o_vals = list(others.values())

    # With an empty list of values there is nothing to place in that key, so
    # no complete dictionary exists. Stop here rather than letting next() on
    # an empty cycle raise StopIteration, which a generator turns into an
    # opaque RuntimeError (PEP 479).
    if p_product_len == 0 or any(len(values) == 0 for values in o_vals):
        return

    # cyclic generators for the product of all the permuted values and for
    # each list of other values
    p_generator = itertools.cycle(itertools.product(*permutes.values()))
    o_generators = [itertools.cycle(values) for values in o_vals]

    # The number of params dictionaries generated will be the maximum size
    # of the permuted values and the non-permuted values from others
    max_items = max([p_product_len] + [len(values) for values in o_vals])

    for _ in range(max_items):
        # the permuted values generator returns a value for each of the
        # permuted keys in the same order as they were originally given
        params = dict(zip(p_keys, next(p_generator)))
        # there is a separate generator for each of the others values
        for key, generator in zip(o_keys, o_generators):
            params[key] = next(generator)
        yield params
106
107
108class Operator:
109 """Base class for operator specific selection properties."""
110
111 # A registry of all Operator subclasses, indexed by the operator name
112 registry = {}
113
114 def __init_subclass__(cls, **kwargs):
115 """Subclass initialiser to register all Operator classes."""
116 super().__init_subclass__(**kwargs)
117 cls.registry[cls.name] = cls
118
119 # Derived classes must override the operator name
120 name = None
121 # Operators with additional parameters must override the param_names
122 # NB: the order must match the order the values appear in the test names
123 param_names = ["shape", "type"]
124
125 # Working set of param_names - updated for negative tests
126 wks_param_names = None
127
128 def __init__(
129 self,
130 test_dir: Path,
131 config: Dict[str, Dict[str, List[Any]]],
132 negative=False,
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +0000133 ignore_missing=False,
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100134 ):
135 """Initialise the selection parameters for an operator.
136
James Ward736fd1a2023-01-23 17:13:37 +0000137 test_dir: the directory where the tests for all operators can
138 be found
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100139 config: a dictionary with:
James Ward736fd1a2023-01-23 17:13:37 +0000140 "params" - a dictionary with mappings of parameter
141 names to the values to select (a sub-set of
142 expected values for instance)
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100143 "permutes" - a list of parameter names to be permuted
James Ward736fd1a2023-01-23 17:13:37 +0000144 "preselected" - a list of dictionaries containing
145 parameter names and pre-chosen values
146 "sparsity" - a dictionary of parameter names with a
147 sparsity value
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +0000148 "full_sparsity" - "true"/"false" to use the sparsity
149 value on permutes/params/preselected
James Ward736fd1a2023-01-23 17:13:37 +0000150 "exclude_patterns" - a list of regex's whereby each
151 match will not be considered for selection.
152 Exclusion happens BEFORE test selection (i.e.
153 before permutes are applied).
154 "errorifs" - list of ERRORIF case names to be selected
Jeremy Johnsondd3e9aa2023-02-06 16:58:04 +0000155 after exclusion (negative tests)
James Ward736fd1a2023-01-23 17:13:37 +0000156 negative: bool indicating if negative testing is being selected
Jeremy Johnsondd3e9aa2023-02-06 16:58:04 +0000157 which filters for ERRORIF in the test name and only selects
158 the first test found (ERRORIF tests)
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +0000159 ignore_missing: bool indicating if missing tests should be ignored
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100160
Jeremy Johnsondd3e9aa2023-02-06 16:58:04 +0000161 EXAMPLE CONFIG (with non-json comments):
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100162 "params": {
163 "output_type": [
164 "outi8",
165 "outb"
166 ]
167 },
168 "permutes": [
169 "shape",
170 "type"
171 ],
172 "sparsity": {
173 "pad": 15
174 },
175 "preselected": [
176 {
177 "shape": "6",
178 "type": "i8",
179 "pad": "pad00"
180 }
181 ],
James Ward736fd1a2023-01-23 17:13:37 +0000182 "exclude_patterns": [
Jeremy Johnsondd3e9aa2023-02-06 16:58:04 +0000183 # Exclude positive (not ERRORIF) integer tests
184 "^((?!ERRORIF).)*_(i8|i16|i32|b)_out(i8|i16|i32|b)",
185 # Exclude negative (ERRORIF) i8 test
186 ".*_ERRORIF_.*_i8_outi8"
James Ward736fd1a2023-01-23 17:13:37 +0000187 ],
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100188 "errorifs": [
189 "InputZeroPointNotZero"
190 ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100191 """
192 assert isinstance(
193 self.name, str
194 ), f"{self.__class__.__name__}: {self.name} is not a valid operator name"
195
196 self.negative = negative
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +0000197 self.ignore_missing = ignore_missing
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100198 self.wks_param_names = self.param_names.copy()
199 if self.negative:
200 # need to override positive set up - use "errorifs" config if set
201 # add in errorif case before shape to support all ops, including
202 # different ops like COND_IF and CONVnD etc
203 index = self.wks_param_names.index("shape")
204 self.wks_param_names[index:index] = ["ERRORIF", "case"]
205 config["params"] = {x: [] for x in self.wks_param_names}
206 config["params"]["case"] = (
207 config["errorifs"] if "errorifs" in config else []
208 )
209 config["permutes"] = []
210 config["preselected"] = {}
211
212 self.params = config["params"] if "params" in config else {}
213 self.permutes = config["permutes"] if "permutes" in config else []
214 self.sparsity = config["sparsity"] if "sparsity" in config else {}
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +0000215 self.full_sparsity = (
216 (config["full_sparsity"] == "true") if "full_sparsity" in config else False
217 )
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100218 self.preselected = config["preselected"] if "preselected" in config else {}
James Ward736fd1a2023-01-23 17:13:37 +0000219 self.exclude_patterns = (
220 config["exclude_patterns"] if "exclude_patterns" in config else []
221 )
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100222 self.non_permutes = [x for x in self.wks_param_names if x not in self.permutes]
223 logger.info(f"{self.name}: permutes={self.permutes}")
224 logger.info(f"{self.name}: non_permutes={self.non_permutes}")
James Ward736fd1a2023-01-23 17:13:37 +0000225 logger.info(f"{self.name}: exclude_patterns={self.exclude_patterns}")
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100226
James Ward736fd1a2023-01-23 17:13:37 +0000227 self.test_paths = []
228 excluded_paths = []
229 for path in self.get_test_paths(test_dir, self.negative):
230 pattern_match = False
231 for pattern in self.exclude_patterns:
232 if re.fullmatch(pattern, path.name):
233 excluded_paths.append(path)
234 pattern_match = True
235 break
236 if not pattern_match:
237 self.test_paths.append(path)
238
239 logger.debug(f"{self.name}: regex excluded paths={excluded_paths}")
240
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100241 if not self.test_paths:
242 logger.error(f"no tests found for {self.name} in {test_dir}")
243 logger.debug(f"{self.name}: paths={self.test_paths}")
244
245 # get default parameter values for any not given in the config
246 default_params = self.get_default_params()
247 for param in default_params:
248 if param not in self.params or not self.params[param]:
249 self.params[param] = default_params[param]
250 for param in self.wks_param_names:
251 logger.info(f"{self.name}: params[{param}]={self.params[param]}")
252
253 @staticmethod
254 def _get_test_paths(test_dir: Path, base_dir_glob, path_glob, negative):
255 """Generate test paths for operators using operator specifics."""
256 for base_dir in sorted(test_dir.glob(base_dir_glob)):
257 for path in sorted(base_dir.glob(path_glob)):
258 if (not negative and "ERRORIF" not in str(path)) or (
259 negative and "ERRORIF" in str(path)
260 ):
Jeremy Johnson30476252023-11-20 16:15:30 +0000261 # Check for test set paths
262 match = re.match(r"(.*)_s([0-9]+)", path.name)
263 if match:
264 if match.group(2) == "0":
265 # Only return the truncated test name
266 # of the first test of a set
267 yield path.with_name(match.group(1))
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100268 else:
269 yield path
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100270
271 @classmethod
272 def get_test_paths(cls, test_dir: Path, negative):
273 """Generate test paths for this operator."""
274 yield from Operator._get_test_paths(test_dir, f"{cls.name}*", "*", negative)
275
276 def path_params(self, path):
277 """Return a dictionary of params from the test path."""
278 params = {}
279 op_name_parts = self.name.split("_")
280 values = path.name.split("_")[len(op_name_parts) :]
281 assert len(values) == len(
282 self.wks_param_names
283 ), f"len({values}) == len({self.wks_param_names})"
284 for i, param in enumerate(self.wks_param_names):
285 params[param] = values[i]
286 return params
287
288 def get_default_params(self):
289 """Get the default parameter values from the test names."""
290 params = {param: set() for param in self.wks_param_names}
291 for path in self.test_paths:
292 path_params = self.path_params(path)
293 for k in params:
294 params[k].add(path_params[k])
295 for param in params:
296 params[param] = sorted(list(params[param]))
297 return params
298
Jeremy Johnson30476252023-11-20 16:15:30 +0000299 @staticmethod
300 def _get_test_set_paths(path):
301 """Expand a path to find all the test sets."""
302 s = 0
303 paths = []
304 # Have a bound for the maximum test sets
305 while s < 100:
306 set_path = path.with_name(f"{path.name}_s{s}")
307 if set_path.exists():
308 paths.append(set_path)
309 else:
310 if s == 0:
311 logger.error(f"Could not find test set 0 - {str(set_path)}")
312 break
313 s += 1
314 return paths
315
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100316 def select_tests(self): # noqa: C901 (function too complex)
317 """Generate the paths to the selected tests for this operator."""
318 if not self.test_paths:
319 # Exit early when nothing to select from
320 return
321
322 # the test paths that have not been selected yet
323 unused_paths = set(self.test_paths)
324
325 # a list of dictionaries of unused preselected parameter combinations
326 unused_preselected = [x for x in self.preselected]
327 logger.debug(f"preselected: {unused_preselected}")
328
329 # a list of dictionaries of unused permuted parameter combinations
330 permutes = {k: self.params[k] for k in self.permutes}
331 others = {k: self.params[k] for k in self.non_permutes}
332 unused_permuted = [x for x in expand_params(permutes, others)]
333 logger.debug(f"permuted: {unused_permuted}")
334
335 # a dictionary of sets of unused parameter values
336 if self.negative:
337 # We only care about selecting a test for each errorif case
338 unused_values = {k: set() for k in self.params}
339 unused_values["case"] = set(self.params["case"])
340 else:
341 unused_values = {k: set(v) for k, v in self.params.items()}
342
343 # select tests matching permuted, or preselected, parameter combinations
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +0000344 for n, path in enumerate(self.test_paths):
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100345 path_params = self.path_params(path)
346 if path_params in unused_permuted or path_params in unused_preselected:
347 unused_paths.remove(path)
348 if path_params in unused_preselected:
349 unused_preselected.remove(path_params)
350 if path_params in unused_permuted:
351 unused_permuted.remove(path_params)
352 if self.negative:
353 # remove any other errorif cases, so we only match one
354 for p in list(unused_permuted):
355 if p["case"] == path_params["case"]:
356 unused_permuted.remove(p)
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +0000357 if self.full_sparsity:
358 # Test for sparsity
359 skip = False
360 for k in path_params:
361 if k in self.sparsity and n % self.sparsity[k] != 0:
362 logger.debug(f"Skipping due to {k} sparsity - {path.name}")
363 skip = True
364 break
365 if skip:
366 continue
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100367 # remove the param values used by this path
368 for k in path_params:
369 unused_values[k].discard(path_params[k])
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +0000370 logger.debug(f"FOUND wanted: {path.name}")
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100371 if path.exists():
372 yield path
373 else:
Jeremy Johnson30476252023-11-20 16:15:30 +0000374 # Must be a test set - expand to all test sets
375 for p in Operator._get_test_set_paths(path):
376 yield p
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100377
378 # search for tests that match any unused parameter values
379 for n, path in enumerate(sorted(list(unused_paths))):
380 path_params = self.path_params(path)
381 # select paths with unused param values
382 # skipping some, if sparsity is set for the param
383 for k in path_params:
384 if path_params[k] in unused_values[k] and (
385 k not in self.sparsity or n % self.sparsity[k] == 0
386 ):
387 # remove the param values used by this path
388 for p in path_params:
389 unused_values[p].discard(path_params[p])
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +0000390 sparsity = self.sparsity[k] if k in self.sparsity else 0
391 logger.debug(f"FOUND unused [{k}/{n}/{sparsity}]: {path.name}")
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100392 if path.exists():
393 yield path
394 else:
Jeremy Johnson30476252023-11-20 16:15:30 +0000395 # Must be a test set - expand to all test sets
396 for p in Operator._get_test_set_paths(path):
397 yield p
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100398 break
399
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +0000400 if not self.ignore_missing:
401 # report any preselected combinations that were not found
402 for params in unused_preselected:
403 logger.warning(f"MISSING preselected: {params}")
404 # report any permuted combinations that were not found
405 for params in unused_permuted:
406 logger.debug(f"MISSING permutation: {params}")
407 # report any param values that were not found
408 for k, values in unused_values.items():
409 if values:
410 if k not in self.sparsity:
411 logger.warning(
412 f"MISSING {len(values)} values for {k}: {values}"
413 )
414 else:
415 logger.info(
416 f"Skipped {len(values)} values for {k} due to sparsity setting"
417 )
418 logger.debug(f"Values skipped: {values}")
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100419
420
class AbsOperator(Operator):
    """Selection rules for ABS operator tests."""

    name = "abs"
425
426
class ArithmeticRightShiftOperator(Operator):
    """Selection rules for ARITHMETIC_RIGHT_SHIFT operator tests."""

    name = "arithmetic_right_shift"
    param_names = [
        "shape",
        "type",
        "rounding",
    ]
432
433
class AddOperator(Operator):
    """Selection rules for ADD operator tests."""

    name = "add"
438
439
class AddShapeOperator(Operator):
    """Selection rules for ADD_SHAPE operator tests."""

    name = "add_shape"
444
445
class ArgmaxOperator(Operator):
    """Selection rules for ARGMAX operator tests."""

    name = "argmax"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
451
452
class AvgPool2dOperator(Operator):
    """Selection rules for AVG_POOL2D operator tests."""

    name = "avg_pool2d"
    param_names = [
        "shape",
        "type",
        "accum_type",
        "stride",
        "kernel",
        "pad",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100458
459
class BitwiseAndOperator(Operator):
    """Selection rules for BITWISE_AND operator tests."""

    name = "bitwise_and"
464
465
class BitwiseNotOperator(Operator):
    """Selection rules for BITWISE_NOT operator tests."""

    name = "bitwise_not"
470
471
class BitwiseOrOperator(Operator):
    """Selection rules for BITWISE_OR operator tests."""

    name = "bitwise_or"
476
477
class BitwiseXorOperator(Operator):
    """Selection rules for BITWISE_XOR operator tests."""

    name = "bitwise_xor"
482
483
class CastOperator(Operator):
    """Selection rules for CAST operator tests."""

    name = "cast"
    param_names = [
        "shape",
        "type",
        "output_type",
    ]
489
490
class CeilOperator(Operator):
    """Selection rules for CEIL operator tests."""

    name = "ceil"
495
496
class ClampOperator(Operator):
    """Selection rules for CLAMP operator tests."""

    name = "clamp"
501
502
class CLZOperator(Operator):
    """Selection rules for CLZ operator tests."""

    name = "clz"
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100507
508
class ConcatOperator(Operator):
    """Selection rules for CONCAT operator tests."""

    name = "concat"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
514
515
class ConcatShapeOperator(Operator):
    """Selection rules for CONCAT_SHAPE operator tests."""

    name = "concat_shape"
520
521
class CondIfOperator(Operator):
    """Selection rules for COND_IF operator tests."""

    name = "cond_if"
    param_names = [
        "variant",
        "shape",
        "type",
        "cond",
    ]
527
528
class ConstOperator(Operator):
    """Selection rules for CONST operator tests."""

    name = "const"
533
534
class ConstShapeOperator(Operator):
    """Selection rules for CONST_SHAPE operator tests."""

    name = "const_shape"
539
540
class Conv2dOperator(Operator):
    """Selection rules for CONV2D operator tests."""

    name = "conv2d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100546
547
class Conv3dOperator(Operator):
    """Selection rules for CONV3D operator tests."""

    name = "conv3d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100553
554
class DepthwiseConv2dOperator(Operator):
    """Selection rules for DEPTHWISE_CONV2D operator tests."""

    name = "depthwise_conv2d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100560
561
class DimOperator(Operator):
    """Test selector for the DIM operator."""

    name = "dim"
    param_names = ["shape", "type", "axis"]


# Backward-compatible alias for the original, misspelt, class name
DimOeprator = DimOperator
567
568
class DivShapeOperator(Operator):
    """Selection rules for DIV_SHAPE operator tests."""

    name = "div_shape"
573
574
class EqualOperator(Operator):
    """Selection rules for EQUAL operator tests."""

    name = "equal"
579
580
class ExpOperator(Operator):
    """Selection rules for EXP operator tests."""

    name = "exp"
585
586
class ErfOperator(Operator):
    """Selection rules for ERF operator tests."""

    name = "erf"
591
592
class FFT2DOperator(Operator):
    """Selection rules for FFT2D operator tests."""

    name = "fft2d"
    param_names = [
        "shape",
        "type",
        "inverse",
    ]
598
599
class FloorOperator(Operator):
    """Selection rules for FLOOR operator tests."""

    name = "floor"
604
605
class FullyConnectedOperator(Operator):
    """Selection rules for FULLY_CONNECTED operator tests."""

    name = "fully_connected"
    param_names = [
        "shape",
        "type",
        "accum_type",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100611
612
class GatherOperator(Operator):
    """Selection rules for GATHER operator tests."""

    name = "gather"
617
618
class GreaterOperator(Operator):
    """Selection rules for GREATER operator tests."""

    name = "greater"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate this operator's test paths from its exact-name directory.

        The base directory glob has no trailing wildcard, so unlike a
        ``greater*`` glob it does not also match ``greater_equal``.
        """
        base_glob = f"{cls.name}"
        yield from Operator._get_test_paths(test_dir, base_glob, "*", negative)
628
629
class GreaterEqualOperator(Operator):
    """Selection rules for GREATER_EQUAL operator tests."""

    name = "greater_equal"
634
635
class IdentityOperator(Operator):
    """Selection rules for IDENTITY operator tests."""

    name = "identity"
640
641
class IntDivOperator(Operator):
    """Selection rules for INTDIV operator tests."""

    name = "intdiv"
646
647
class LogOperator(Operator):
    """Selection rules for LOG operator tests."""

    name = "log"
652
653
class LogicalAndOperator(Operator):
    """Selection rules for LOGICAL_AND operator tests."""

    name = "logical_and"
658
659
class LogicalLeftShiftOperator(Operator):
    """Selection rules for LOGICAL_LEFT_SHIFT operator tests."""

    name = "logical_left_shift"
664
665
class LogicalNotOperator(Operator):
    """Selection rules for LOGICAL_NOT operator tests."""

    name = "logical_not"
670
671
class LogicalOrOperator(Operator):
    """Selection rules for LOGICAL_OR operator tests."""

    name = "logical_or"
676
677
class LogicalRightShiftOperator(Operator):
    """Selection rules for LOGICAL_RIGHT_SHIFT operator tests."""

    name = "logical_right_shift"
682
683
class LogicalXorOperator(Operator):
    """Selection rules for LOGICAL_XOR operator tests."""

    name = "logical_xor"
688
689
class MatmulOperator(Operator):
    """Selection rules for MATMUL operator tests."""

    name = "matmul"
    param_names = [
        "shape",
        "type",
        "accum_type",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100695
696
class MaximumOperator(Operator):
    """Selection rules for MAXIMUM operator tests."""

    name = "maximum"
701
702
class MaxPool2dOperator(Operator):
    """Selection rules for MAX_POOL2D operator tests."""

    name = "max_pool2d"
    param_names = [
        "shape",
        "type",
        "stride",
        "kernel",
        "pad",
    ]
708
709
class MinimumOperator(Operator):
    """Selection rules for MINIMUM operator tests."""

    name = "minimum"
714
715
class MulOperator(Operator):
    """Selection rules for MUL operator tests."""

    name = "mul"
    param_names = [
        "shape",
        "type",
        "perm",
        "shift",
    ]
721
722
class MulShapeOperator(Operator):
    """Selection rules for MUL_SHAPE operator tests."""

    name = "mul_shape"
727
728
class NegateOperator(Operator):
    """Selection rules for NEGATE operator tests."""

    name = "negate"
733
734
class PadOperator(Operator):
    """Selection rules for PAD operator tests."""

    name = "pad"
    param_names = [
        "shape",
        "type",
        "pad",
    ]
740
741
class PowOperator(Operator):
    """Selection rules for POW operator tests."""

    name = "pow"
746
747
class ReciprocalOperator(Operator):
    """Selection rules for RECIPROCAL operator tests."""

    name = "reciprocal"
752
753
class ReduceAllOperator(Operator):
    """Selection rules for REDUCE_ALL operator tests."""

    name = "reduce_all"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
759
760
class ReduceAnyOperator(Operator):
    """Selection rules for REDUCE_ANY operator tests."""

    name = "reduce_any"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
766
767
class ReduceMaxOperator(Operator):
    """Selection rules for REDUCE_MAX operator tests."""

    name = "reduce_max"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
773
774
class ReduceMinOperator(Operator):
    """Selection rules for REDUCE_MIN operator tests."""

    name = "reduce_min"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
780
781
class ReduceProductOperator(Operator):
    """Selection rules for REDUCE_PRODUCT operator tests."""

    name = "reduce_product"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
787
788
class ReduceSumOperator(Operator):
    """Selection rules for REDUCE_SUM operator tests."""

    name = "reduce_sum"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
794
795
class RescaleOperator(Operator):
    """Selection rules for RESCALE operator tests."""

    name = "rescale"
    param_names = [
        "shape",
        "type",
        "output_type",
        "scale",
        "double_round",
        "per_channel",
    ]
808
809
class ReshapeOperator(Operator):
    """Selection rules for RESHAPE operator tests."""

    name = "reshape"
    param_names = [
        "shape",
        "type",
        "perm",
        "rank",
        "out",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100815
816
class ResizeOperator(Operator):
    """Selection rules for RESIZE operator tests."""

    name = "resize"
    param_names = [
        "shape",
        "type",
        "mode",
        "output_type",
        "scale",
        "offset",
        "border",
    ]
830
831
class ReverseOperator(Operator):
    """Selection rules for REVERSE operator tests."""

    name = "reverse"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
837
838
class RFFT2DOperator(Operator):
    """Selection rules for RFFT2D operator tests."""

    name = "rfft2d"
843
844
class RsqrtOperator(Operator):
    """Selection rules for RSQRT operator tests."""

    name = "rsqrt"
849
850
class ScatterOperator(Operator):
    """Selection rules for SCATTER operator tests."""

    name = "scatter"
855
856
class SelectOperator(Operator):
    """Selection rules for SELECT operator tests."""

    name = "select"
861
862
class SigmoidOperator(Operator):
    """Selection rules for SIGMOID operator tests."""

    name = "sigmoid"
867
868
class SliceOperator(Operator):
    """Selection rules for SLICE operator tests."""

    name = "slice"
    param_names = [
        "shape",
        "type",
        "perm",
    ]
874
875
class SubOperator(Operator):
    """Selection rules for SUB operator tests."""

    name = "sub"
880
881
class SubShapeOperator(Operator):
    """Selection rules for SUB_SHAPE operator tests."""

    name = "sub_shape"
886
887
class TableOperator(Operator):
    """Selection rules for TABLE operator tests."""

    name = "table"
892
893
class TanhOperator(Operator):
    """Selection rules for TANH operator tests."""

    name = "tanh"
898
899
class TileOperator(Operator):
    """Selection rules for TILE operator tests."""

    name = "tile"
    param_names = [
        "shape",
        "type",
        "perm",
    ]
905
906
class TransposeOperator(Operator):
    """Selection rules for TRANSPOSE operator tests."""

    name = "transpose"
    param_names = [
        "shape",
        "type",
        "perm",
    ]

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate this operator's test paths from its exact-name directory.

        The base directory glob has no trailing wildcard, so unlike a
        ``transpose*`` glob it does not also match ``transpose_conv2d``.
        """
        base_glob = f"{cls.name}"
        yield from Operator._get_test_paths(test_dir, base_glob, "*", negative)
917
918
class TransposeConv2dOperator(Operator):
    """Selection rules for TRANSPOSE_CONV2D operator tests."""

    name = "transpose_conv2d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "out_shape",
    ]

    def path_params(self, path):
        """Parse the params from the test name, blanking out ``out_shape``."""
        parsed = super().path_params(path)
        # every test case has its own out_shape value, so it is not useful
        # for telling tests apart during selection
        parsed.update(out_shape="")
        return parsed
939
940
class WhileLoopOperator(Operator):
    """Selection rules for WHILE_LOOP operator tests."""

    name = "while_loop"
    param_names = [
        "shape",
        "type",
        "cond",
    ]
946
947
def parse_args():
    """Parse the command line arguments.

    Returns the argparse namespace with the attributes test_dir, config,
    selector, full_path, verbosity, operators and test_type.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-dir",
        default=Path.cwd(),
        type=Path,
        help=(
            "The directory where test subdirectories for all operators can be found"
            " (default: current working directory)"
        ),
    )
    parser.add_argument(
        "--config",
        default=Path(__file__).with_suffix(".json"),
        type=Path,
        help="A JSON file defining the parameters to use for each operator",
    )
    parser.add_argument(
        "--selector",
        default="default",
        type=str,
        help="The selector in the selection dictionary to use for each operator",
    )
    parser.add_argument(
        "--full-path", action="store_true", help="output the full path for each test"
    )
    parser.add_argument(
        "-v",
        dest="verbosity",
        action="count",
        default=0,
        help="Verbosity (can be used multiple times for more details)",
    )
    parser.add_argument(
        "operators",
        type=str,
        nargs="*",
        help=(
            "Select tests for the specified operator(s)"
            " - all operators are assumed if none are specified"
            f" - choose from: {list(Operator.registry)}"
        ),
    )
    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative"],
        default="positive",
        type=str,
        help="type of tests selected, positive or negative",
    )
    return parser.parse_args()
1001
1002
def main():
    """Select tests for the requested operators and print them, one per line.

    Returns the process exit code: 0 on success, 2 if the config file could
    not be read or parsed.
    """
    args = parse_args()

    loglevels = (logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG)
    logger.setLevel(loglevels[min(args.verbosity, len(loglevels) - 1)])
    logger.info(f"{__file__}: args: {args}")

    try:
        # JSON text is UTF-8 (RFC 8259) whatever the platform locale is
        with open(args.config, "r", encoding="utf-8") as fd:
            config = json.load(fd)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError
        logger.error(f"Config file error: {e}")
        return 2

    # Unregistered operator names would otherwise be silently ignored and
    # produce no output at all; report them at the default (error) log level
    for op_name in args.operators:
        if op_name not in Operator.registry:
            logger.error(f"Unknown operator ignored: {op_name}")

    negative = args.test_type == "negative"
    for op_name, op_class in Operator.registry.items():
        if args.operators and op_name not in args.operators:
            continue
        op_params = config[op_name] if op_name in config else {}
        if "selection" in op_params and args.selector in op_params["selection"]:
            selection_config = op_params["selection"][args.selector]
        else:
            logger.warning(
                f"Could not find selection config {args.selector} for {op_name}"
            )
            selection_config = {}
        op = op_class(args.test_dir, selection_config, negative)
        for test_path in op.select_tests():
            print(test_path.resolve() if args.full_path else test_path.name)

    return 0
1034
1035
if __name__ == "__main__":
    # sys.exit rather than the exit() builtin, which is only injected by the
    # site module and is not available under "python -S"
    import sys

    sys.exit(main())