blob: 2b8e7d2e2af7541718feeb868263eccc42be46c5 [file] [log] [blame]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
3"""Select generated tests."""
4import argparse
5import itertools
6import json
7import logging
8from pathlib import Path
9from typing import Any
10from typing import Dict
11from typing import List
12
13logging.basicConfig()
14logger = logging.getLogger("test_select")
15
16
def expand_params(permutes: Dict[str, List[Any]], others: Dict[str, List[Any]]):
    """Generate permuted combinations of a dictionary of values and combine with others.

    permutes: a dictionary with sequences of values to be fully permuted
    others: a dictionary with sequences of values not fully permuted, but all used

    This yields dictionaries with one value from each of the items in permutes,
    combined with one value from each of the items in others.

    The number of dictionaries generated is the larger of the number of
    permutations and the length of the longest sequence in others; the
    permutations and every sequence in others are cycled to fill all entries.
    When permutes and others are both empty a single empty dictionary is
    generated. When any item has an empty sequence of values no complete
    dictionary can be formed, so nothing is generated.

    Example 1:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 5, "d": False},
        ]

    Example 2:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7, 8, 9], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 8, "d": False},
            {"a": 1, "b": 3, "c": 9, "d": True},
        ]

    Raises:
        ValueError if any item is in both permutes and others
    """
    for k in permutes:
        if k in others:
            raise ValueError(f"item conflict: {k}")

    # Every generated dictionary needs a value for each key. With an empty
    # sequence there is nothing to choose, and calling next() on a cycle of an
    # empty sequence would raise StopIteration (a RuntimeError inside a
    # generator), so generate nothing instead.
    all_values = itertools.chain(permutes.values(), others.values())
    if any(len(values) == 0 for values in all_values):
        return

    p_keys = list(permutes)
    p_vals = [permutes[k] for k in p_keys]
    # Number of permutations. It is 1 when permutes is empty, because
    # itertools.product() then yields a single empty tuple; this gives one
    # empty dictionary when others is also empty.
    p_product_len = 1
    for values in p_vals:
        p_product_len *= len(values)
    # a cyclic generator over every combination of the permuted values, each
    # combination ordered like p_keys
    p_generator = itertools.cycle(itertools.product(*p_vals))

    # a separate cyclic generator for each sequence of other values
    o_keys = list(others)
    o_generators = [itertools.cycle(others[k]) for k in o_keys]

    # The number of params dictionaries generated is the maximum size
    # of the permuted values and the non-permuted values from others
    max_items = max([p_product_len] + [len(values) for values in others.values()])

    for _ in range(max_items):
        # one value for each permuted key, in the order they were given
        params = dict(zip(p_keys, next(p_generator)))
        # followed by the next value for each of the others keys
        for key, generator in zip(o_keys, o_generators):
            params[key] = next(generator)
        yield params
105
106
class Operator:
    """Base class for operator specific selection properties.

    Each subclass sets ``name`` (and ``param_names`` when its test names hold
    more than a shape and a type) and is registered automatically in
    ``registry`` under that name.
    """

    # A registry of all Operator subclasses, indexed by the operator name
    registry = {}

    def __init_subclass__(cls, **kwargs):
        """Subclass initialiser to register all Operator classes."""
        super().__init_subclass__(**kwargs)
        cls.registry[cls.name] = cls

    # Derived classes must override the operator name
    name = None
    # Operators with additional parameters must override the param_names
    # NB: the order must match the order the values appear in the test names
    param_names = ["shape", "type"]

    # Working set of param_names - updated for negative tests
    wks_param_names = None

    def __init__(
        self,
        test_dir: Path,
        config: Dict[str, Any],
        negative=False,
        exclude_types=None,
    ):
        """Initialise the selection parameters for an operator.

        test_dir: the directory where the tests for all operators can be found
        config: a dictionary with:
            "params" - a dictionary with mappings of parameter names to the values
                to select (a sub-set of expected values for instance)
            "permutes" - a list of parameter names to be permuted
            "preselected" - a list of dictionaries containing parameter names and
                pre-chosen values
            "sparsity" - a dictionary of parameter names with a sparsity value
            "errorifs" - list of ERRORIF case names to be selected (negative test)
            The config dictionary itself is not modified.
        negative: bool indicating if negative testing is being selected (ERRORIF tests)
        exclude_types: optional collection of "type" parameter values; tests
            whose type is in it are ignored

        EXAMPLE CONFIG:
            "params": {
                "output_type": [
                    "outi8",
                    "outb"
                ]
            },
            "permutes": [
                "shape",
                "type"
            ],
            "sparsity": {
                "pad": 15
            },
            "preselected": [
                {
                    "shape": "6",
                    "type": "i8",
                    "pad": "pad00"
                }
            ],
            "errorifs": [
                "InputZeroPointNotZero"
            ]
        """
        assert isinstance(
            self.name, str
        ), f"{self.__class__.__name__}: {self.name} is not a valid operator name"

        self.negative = negative
        self.wks_param_names = self.param_names.copy()
        if self.negative:
            # Negative selection overrides the positive set up and only uses the
            # "errorifs" config (plus any sparsity). The ERRORIF and case values
            # are inserted before shape to support all ops, including ones like
            # COND_IF and CONVnD where shape is not the first parameter.
            index = self.wks_param_names.index("shape")
            self.wks_param_names[index:index] = ["ERRORIF", "case"]
            params = {x: [] for x in self.wks_param_names}
            params["case"] = config.get("errorifs") or []
            permutes = []
            preselected = []
        else:
            params = config.get("params") or {}
            permutes = config.get("permutes") or []
            preselected = config.get("preselected") or []

        # Copy the params mapping: defaults are filled into it below and that
        # must not leak back into the caller's config.
        self.params = dict(params)
        self.permutes = permutes
        self.sparsity = config.get("sparsity") or {}
        self.preselected = preselected
        self.non_permutes = [x for x in self.wks_param_names if x not in self.permutes]
        logger.info(f"{self.name}: permutes={self.permutes}")
        logger.info(f"{self.name}: non_permutes={self.non_permutes}")

        if exclude_types is None:
            exclude_types = []
        self.test_paths = [
            p
            for p in self.get_test_paths(test_dir, self.negative)
            # exclusion of types if requested
            if self.path_params(p)["type"] not in exclude_types
        ]
        if not self.test_paths:
            logger.error(f"no tests found for {self.name} in {test_dir}")
        logger.debug(f"{self.name}: paths={self.test_paths}")

        # get default parameter values for any not given (or empty) in the config
        default_params = self.get_default_params()
        for param, values in default_params.items():
            if not self.params.get(param):
                self.params[param] = values
        for param in self.wks_param_names:
            logger.info(f"{self.name}: params[{param}]={self.params[param]}")

    @staticmethod
    def _get_test_paths(test_dir: Path, base_dir_glob, path_glob, negative):
        """Generate test paths for operators using operator specifics.

        Looks in every directory of test_dir matching base_dir_glob for entries
        matching path_glob, both in sorted order. Paths containing "ERRORIF"
        are generated only when negative is set; all others only when it is not.
        """
        for base_dir in sorted(test_dir.glob(base_dir_glob)):
            for path in sorted(base_dir.glob(path_glob)):
                if ("ERRORIF" in str(path)) == bool(negative):
                    yield path

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths for this operator."""
        yield from Operator._get_test_paths(test_dir, f"{cls.name}*", "*", negative)

    def path_params(self, path):
        """Return a dictionary of params from the test path.

        The path name is the operator name followed by one "_" separated value
        per entry of wks_param_names, in the same order.
        """
        op_name_parts = self.name.split("_")
        values = path.name.split("_")[len(op_name_parts) :]
        assert len(values) == len(
            self.wks_param_names
        ), f"len({values}) == len({self.wks_param_names})"
        return dict(zip(self.wks_param_names, values))

    def get_default_params(self):
        """Get the default parameter values from the test names.

        Returns a dictionary mapping each of wks_param_names to the sorted list
        of distinct values found for it across self.test_paths.
        """
        params = {param: set() for param in self.wks_param_names}
        for path in self.test_paths:
            path_params = self.path_params(path)
            for k, values in params.items():
                values.add(path_params[k])
        return {param: sorted(values) for param, values in params.items()}

    def select_tests(self):  # noqa: C901 (function too complex)
        """Generate the paths to the selected tests for this operator.

        Tests are selected in two passes:
        1. tests whose params match an unused permuted or preselected
           combination (for negative tests only one test per ERRORIF case)
        2. from the remaining tests, any test with a param value that has not
           been used yet, skipping some tests when sparsity is set for the param
        Combinations and values that could not be found are logged at the end.
        """
        if not self.test_paths:
            # Exit early when nothing to select from
            return

        # the test paths that have not been selected yet
        unused_paths = set(self.test_paths)

        # a list of dictionaries of unused preselected parameter combinations
        unused_preselected = list(self.preselected)
        logger.debug(f"preselected: {unused_preselected}")

        # a list of dictionaries of unused permuted parameter combinations
        permutes = {k: self.params[k] for k in self.permutes}
        others = {k: self.params[k] for k in self.non_permutes}
        unused_permuted = list(expand_params(permutes, others))
        logger.debug(f"permuted: {unused_permuted}")

        # a dictionary of sets of unused parameter values
        if self.negative:
            # We only care about selecting a test for each errorif case
            unused_values = {k: set() for k in self.params}
            unused_values["case"] = set(self.params["case"])
        else:
            unused_values = {k: set(v) for k, v in self.params.items()}

        # select tests matching permuted, or preselected, parameter combinations
        for path in self.test_paths:
            path_params = self.path_params(path)
            if path_params in unused_permuted or path_params in unused_preselected:
                unused_paths.remove(path)
                if path_params in unused_preselected:
                    unused_preselected.remove(path_params)
                if path_params in unused_permuted:
                    unused_permuted.remove(path_params)
                if self.negative:
                    # drop the other combinations for this errorif case, so
                    # only one test matches it
                    unused_permuted = [
                        p for p in unused_permuted if p["case"] != path_params["case"]
                    ]
                # remove the param values used by this path
                for k, value in path_params.items():
                    unused_values[k].discard(value)
                logger.debug(f"FOUND: {path.name}")
                yield path

        # search for tests that match any unused parameter values
        for n, path in enumerate(sorted(unused_paths)):
            path_params = self.path_params(path)
            # select paths with unused param values
            # skipping some, if sparsity is set for the param
            for k, value in path_params.items():
                if value in unused_values[k] and (
                    k not in self.sparsity or n % self.sparsity[k] == 0
                ):
                    # remove the param values used by this path
                    for p, used in path_params.items():
                        unused_values[p].discard(used)
                    logger.debug(f"FOUND: {path.name}")
                    yield path
                    break

        # report any preselected combinations that were not found
        for params in unused_preselected:
            logger.warning(f"MISSING preselected: {params}")
        # report any permuted combinations that were not found
        for params in unused_permuted:
            logger.debug(f"MISSING permutation: {params}")
        # report any param values that were not found
        for k, values in unused_values.items():
            if values:
                if k not in self.sparsity:
                    logger.warning(f"MISSING {len(values)} values for {k}: {values}")
                else:
                    logger.info(
                        f"Skipped {len(values)} values for {k} due to sparsity setting"
                    )
                    logger.debug(f"Values skipped: {values}")
336
337
class AbsOperator(Operator):
    """Selection properties for the ABS operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "abs"
342
343
class ArithmeticRightShiftOperator(Operator):
    """Selection properties for the ARITHMETIC_RIGHT_SHIFT operator tests."""

    # registry key, also matched against the test directory names
    name = "arithmetic_right_shift"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "rounding"]
349
350
class AddOperator(Operator):
    """Selection properties for the ADD operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "add"
355
356
class ArgmaxOperator(Operator):
    """Selection properties for the ARGMAX operator tests."""

    # registry key, also matched against the test directory names
    name = "argmax"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "axis"]
362
363
class AvgPool2dOperator(Operator):
    """Selection properties for the AVG_POOL2D operator tests."""

    # registry key, also matched against the test directory names
    name = "avg_pool2d"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "accum_type", "stride", "kernel", "pad"]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100369
370
class BitwiseAndOperator(Operator):
    """Selection properties for the BITWISE_AND operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "bitwise_and"
375
376
class BitwiseNotOperator(Operator):
    """Selection properties for the BITWISE_NOT operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "bitwise_not"
381
382
class BitwiseOrOperator(Operator):
    """Selection properties for the BITWISE_OR operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "bitwise_or"
387
388
class BitwiseXorOperator(Operator):
    """Selection properties for the BITWISE_XOR operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "bitwise_xor"
393
394
class CastOperator(Operator):
    """Selection properties for the CAST operator tests."""

    # registry key, also matched against the test directory names
    name = "cast"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "output_type"]
400
401
class ClampOperator(Operator):
    """Selection properties for the CLAMP operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "clamp"
406
407
class CLZOperator(Operator):
    """Selection properties for the CLZ operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "clz"
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100412
413
class ConcatOperator(Operator):
    """Selection properties for the CONCAT operator tests."""

    # registry key, also matched against the test directory names
    name = "concat"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "axis"]
419
420
class CondIfOperator(Operator):
    """Selection properties for the COND_IF operator tests."""

    # registry key, also matched against the test directory names
    name = "cond_if"
    # parameter values in the order they appear in the test names
    # (the variant comes before the shape for this operator)
    param_names = ["variant", "shape", "type", "cond"]
426
427
class ConstOperator(Operator):
    """Selection properties for the CONST operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "const"
432
433
class Conv2dOperator(Operator):
    """Selection properties for the CONV2D operator tests."""

    # registry key, also matched against the test directory names
    name = "conv2d"
    # parameter values in the order they appear in the test names
    param_names = ["kernel", "shape", "type", "accum_type", "stride", "pad", "dilation"]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100439
440
class Conv3dOperator(Operator):
    """Selection properties for the CONV3D operator tests."""

    # registry key, also matched against the test directory names
    name = "conv3d"
    # parameter values in the order they appear in the test names
    param_names = ["kernel", "shape", "type", "accum_type", "stride", "pad", "dilation"]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100446
447
class DepthwiseConv2dOperator(Operator):
    """Selection properties for the DEPTHWISE_CONV2D operator tests."""

    # registry key, also matched against the test directory names
    name = "depthwise_conv2d"
    # parameter values in the order they appear in the test names
    param_names = ["kernel", "shape", "type", "accum_type", "stride", "pad", "dilation"]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100453
454
class EqualOperator(Operator):
    """Selection properties for the EQUAL operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "equal"
459
460
class FullyConnectedOperator(Operator):
    """Selection properties for the FULLY_CONNECTED operator tests."""

    # registry key, also matched against the test directory names
    name = "fully_connected"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "accum_type"]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100466
467
class GatherOperator(Operator):
    """Selection properties for the GATHER operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "gather"
472
473
class GreaterOperator(Operator):
    """Selection properties for the GREATER operator tests (shape and type only)."""

    # registry key, also matched exactly against the test directory name
    name = "greater"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths from the directory named exactly after the operator.

        Unlike the base class there is no trailing "*" in the directory glob;
        "greater*" would also match the greater_equal test directories.
        """
        exact_dir_glob = f"{cls.name}"
        for test_path in Operator._get_test_paths(
            test_dir, exact_dir_glob, "*", negative
        ):
            yield test_path
483
484
class GreaterEqualOperator(Operator):
    """Selection properties for the GREATER_EQUAL operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "greater_equal"
489
490
class IdentityOperator(Operator):
    """Selection properties for the IDENTITY operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "identity"
495
496
class IntDivOperator(Operator):
    """Selection properties for the INTDIV operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "intdiv"
501
502
class LogicalAndOperator(Operator):
    """Selection properties for the LOGICAL_AND operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "logical_and"
507
508
class LogicalLeftShiftOperator(Operator):
    """Selection properties for the LOGICAL_LEFT_SHIFT operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "logical_left_shift"
513
514
class LogicalNotOperator(Operator):
    """Selection properties for the LOGICAL_NOT operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "logical_not"
519
520
class LogicalOrOperator(Operator):
    """Selection properties for the LOGICAL_OR operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "logical_or"
525
526
class LogicalRightShiftOperator(Operator):
    """Selection properties for the LOGICAL_RIGHT_SHIFT operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "logical_right_shift"
531
532
class LogicalXorOperator(Operator):
    """Selection properties for the LOGICAL_XOR operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "logical_xor"
537
538
class MatmulOperator(Operator):
    """Selection properties for the MATMUL operator tests."""

    # registry key, also matched against the test directory names
    name = "matmul"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "accum_type"]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100544
545
class MaximumOperator(Operator):
    """Selection properties for the MAXIMUM operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "maximum"
550
551
class MaxPool2dOperator(Operator):
    """Selection properties for the MAX_POOL2D operator tests."""

    # registry key, also matched against the test directory names
    name = "max_pool2d"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "stride", "kernel", "pad"]
557
558
class MinimumOperator(Operator):
    """Selection properties for the MINIMUM operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "minimum"
563
564
class MulOperator(Operator):
    """Selection properties for the MUL operator tests."""

    # registry key, also matched against the test directory names
    name = "mul"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "perm", "shift"]
570
571
class NegateOperator(Operator):
    """Selection properties for the NEGATE operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "negate"
576
577
class PadOperator(Operator):
    """Selection properties for the PAD operator tests."""

    # registry key, also matched against the test directory names
    name = "pad"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "pad"]
583
584
class ReduceAllOperator(Operator):
    """Selection properties for the REDUCE_ALL operator tests."""

    # registry key, also matched against the test directory names
    name = "reduce_all"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "axis"]
590
591
class ReduceAnyOperator(Operator):
    """Selection properties for the REDUCE_ANY operator tests."""

    # registry key, also matched against the test directory names
    name = "reduce_any"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "axis"]
597
598
class ReduceMaxOperator(Operator):
    """Selection properties for the REDUCE_MAX operator tests."""

    # registry key, also matched against the test directory names
    name = "reduce_max"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "axis"]
604
605
class ReduceMinOperator(Operator):
    """Selection properties for the REDUCE_MIN operator tests."""

    # registry key, also matched against the test directory names
    name = "reduce_min"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "axis"]
611
612
class ReduceSumOperator(Operator):
    """Selection properties for the REDUCE_SUM operator tests."""

    # registry key, also matched against the test directory names
    name = "reduce_sum"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "axis"]
618
619
class RescaleOperator(Operator):
    """Selection properties for the RESCALE operator tests."""

    # registry key, also matched against the test directory names
    name = "rescale"
    # parameter values in the order they appear in the test names
    param_names = [
        "shape",
        "type",
        "output_type",
        "scale",
        "double_round",
        "per_channel",
    ]
632
633
class ReshapeOperator(Operator):
    """Selection properties for the RESHAPE operator tests."""

    # registry key, also matched against the test directory names
    name = "reshape"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "perm", "rank"]
639
640
class ResizeOperator(Operator):
    """Selection properties for the RESIZE operator tests."""

    # registry key, also matched against the test directory names
    name = "resize"
    # parameter values in the order they appear in the test names
    param_names = [
        "shape",
        "type",
        "mode",
        "output_type",
        "scale",
        "offset",
        "border",
    ]
654
655
class ReverseOperator(Operator):
    """Selection properties for the REVERSE operator tests."""

    # registry key, also matched against the test directory names
    name = "reverse"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "axis"]
661
662
class ScatterOperator(Operator):
    """Selection properties for the SCATTER operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "scatter"
667
668
class SelectOperator(Operator):
    """Selection properties for the SELECT operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "select"
673
674
class SliceOperator(Operator):
    """Selection properties for the SLICE operator tests."""

    # registry key, also matched against the test directory names
    name = "slice"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "perm"]
680
681
class SubOperator(Operator):
    """Selection properties for the SUB operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "sub"
686
687
class TableOperator(Operator):
    """Selection properties for the TABLE operator tests (shape and type only)."""

    # registry key, also matched against the test directory names
    name = "table"
692
693
class TileOperator(Operator):
    """Selection properties for the TILE operator tests."""

    # registry key, also matched against the test directory names
    name = "tile"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "perm"]
699
700
class TransposeOperator(Operator):
    """Selection properties for the TRANSPOSE operator tests."""

    # registry key, also matched exactly against the test directory name
    name = "transpose"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "perm"]

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths from the directory named exactly after the operator.

        Unlike the base class there is no trailing "*" in the directory glob;
        "transpose*" would also match the transpose_conv2d test directories.
        """
        exact_dir_glob = f"{cls.name}"
        for test_path in Operator._get_test_paths(
            test_dir, exact_dir_glob, "*", negative
        ):
            yield test_path
711
712
class TransposeConv2dOperator(Operator):
    """Selection properties for the TRANSPOSE_CONV2D operator tests."""

    # registry key, also matched against the test directory names
    name = "transpose_conv2d"
    # parameter values in the order they appear in the test names
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "out_shape",
    ]

    def path_params(self, path):
        """Return the params parsed from the test path, with out_shape blanked.

        Every test case has its own out_shape, so the value is cleared to keep
        it from driving the selection.
        """
        parsed = super().path_params(path)
        parsed.update(out_shape="")
        return parsed
733
734
class WhileLoopOperator(Operator):
    """Selection properties for the WHILE_LOOP operator tests."""

    # registry key, also matched against the test directory names
    name = "while_loop"
    # parameter values in the order they appear in the test names
    param_names = ["shape", "type", "cond"]
740
741
def parse_args():
    """Parse the command line arguments.

    Returns:
        argparse.Namespace with the attributes test_dir, config, full_path,
        verbosity, operators and test_type
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-dir",
        default=Path.cwd(),
        type=Path,
        help=(
            "The directory where test subdirectories for all operators can be found"
            " (default: current working directory)"
        ),
    )
    parser.add_argument(
        "--config",
        # defaults to a JSON file next to this script with the same base name
        default=Path(__file__).with_suffix(".json"),
        type=Path,
        help="A JSON file defining the parameters to use for each operator",
    )
    parser.add_argument(
        "--full-path", action="store_true", help="output the full path for each test"
    )
    parser.add_argument(
        "-v",
        dest="verbosity",
        action="count",
        default=0,
        help="Verbosity (can be used multiple times for more details)",
    )
    parser.add_argument(
        "operators",
        type=str,
        nargs="*",
        help=(
            "Select tests for the specified operator(s)"
            " - all operators are assumed if none are specified"
            f" - choose from: {list(Operator.registry)}"
        ),
    )
    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative"],
        default="positive",
        type=str,
        help="type of tests selected, positive or negative",
    )
    return parser.parse_args()
789
790
def main():
    """Select the tests for the requested operators and print them, one per line.

    Returns:
        0 on success, 2 if the config file cannot be read or parsed
    """
    args = parse_args()

    loglevels = (logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG)
    # extra -v flags beyond the last level just keep the most detailed level
    logger.setLevel(loglevels[min(args.verbosity, len(loglevels) - 1)])
    logger.info(f"{__file__}: args: {args}")

    try:
        with open(args.config, "r", encoding="utf-8") as fd:
            config = json.load(fd)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError: invalid JSON or encoding
        logger.error(f"Config file error: {e}")
        return 2

    # report requested operators that have no selector, rather than silently
    # producing no tests for them
    for op_name in args.operators:
        if op_name not in Operator.registry:
            logger.error(f"Unknown operator: {op_name}")

    negative = args.test_type == "negative"
    for op_name, op_class in Operator.registry.items():
        if args.operators and op_name not in args.operators:
            continue
        op_params = config[op_name] if op_name in config else {}
        op = op_class(args.test_dir, op_params, negative, exclude_types=["float"])
        for test_path in op.select_tests():
            print(test_path.resolve() if args.full_path else test_path.name)

    return 0
817
818
if __name__ == "__main__":
    # raise SystemExit rather than calling exit(): exit() is an interactive
    # helper installed by the site module and is not always available
    raise SystemExit(main())