blob: 59feae19687db1fa955f6bf08bfb6c871468682b [file] [log] [blame]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
3"""Select generated tests."""
4import argparse
5import itertools
6import json
7import logging
8from pathlib import Path
9from typing import Any
10from typing import Dict
11from typing import List
12
# Logger used throughout this script for progress and diagnostic messages.
logger = logging.getLogger("test_select")
# Give the root logger a default handler (if it has none yet) so that the
# messages logged above are actually emitted.
logging.basicConfig()
15
16
def expand_params(permutes: Dict[str, List[Any]], others: Dict[str, List[Any]]):
    """Generate permuted combinations of a dictionary of values and combine with others.

    permutes: a dictionary with sequences of values to be fully permuted
    others: a dictionary with sequences of values not fully permuted, but all used

    This yields dictionaries with one value from each of the items in permutes,
    combined with one value from each of the items in others.

    The number of dictionaries generated is the larger of the product of the
    number of values of each permutes item and the longest others item;
    shorter sequences are cycled to fill the remaining dictionaries. If both
    permutes and others are empty, a single empty dictionary is generated.

    Example 1:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 5, "d": False},
        ]

    Example 2:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7, 8, 9], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 8, "d": False},
            {"a": 1, "b": 3, "c": 9, "d": True},
        ]

    Raises:
        ValueError if any item is in both permutes and others
        ValueError if dictionaries must be generated but some item has no values
    """
    for k in permutes:
        if k in others:
            raise ValueError(f"item conflict: {k}")

    p_keys = list(permutes)
    p_vals = list(permutes.values())
    # The product of the permuted value counts. With no permutes items this
    # stays 1 (the empty product), so a single empty dictionary is returned
    # when others is also empty.
    p_product_len = 1
    for v in p_vals:
        p_product_len *= len(v)

    o_keys = list(others)
    o_vals = list(others.values())

    # The number of params dictionaries generated will be the maximum size
    # of the permuted values and the non-permuted values from others
    max_items = max([p_product_len] + [len(v) for v in o_vals])
    if max_items == 0:
        # an empty permutes item and nothing (non-empty) in others
        return

    # An empty sequence cannot supply a value for every dictionary; cycling it
    # would raise StopIteration inside this generator (a RuntimeError per
    # PEP 479), so report the offending items clearly instead.
    empty = [k for k, v in zip(p_keys, p_vals) if not v]
    empty += [k for k, v in zip(o_keys, o_vals) if not v]
    if empty:
        raise ValueError(f"no values given for: {empty}")

    # cyclic generator over every combination of the permuted values; each
    # combination is a tuple ordered like p_keys
    p_generator = itertools.cycle(itertools.product(*p_vals))
    # a separate cyclic generator for each of the others values
    o_generators = [itertools.cycle(v) for v in o_vals]

    for _ in range(max_items):
        params = dict(zip(p_keys, next(p_generator)))
        for key, gen in zip(o_keys, o_generators):
            params[key] = next(gen)
        yield params
105
106
class Operator:
    """Base class for operator specific selection properties.

    Subclasses set ``name`` (and ``param_names`` when their test names carry
    extra parameters) and are added to ``registry`` automatically.
    """

    # A registry of all Operator subclasses, indexed by the operator name
    registry = {}

    def __init_subclass__(cls, **kwargs):
        """Subclass initialiser to register all Operator classes."""
        super().__init_subclass__(**kwargs)
        cls.registry[cls.name] = cls

    # Derived classes must override the operator name
    name = None
    # Operators with additional parameters must override the param_names
    # NB: the order must match the order the values appear in the test names
    param_names = ["shape", "type"]

    # Working set of param_names - updated for negative tests
    wks_param_names = None

    def __init__(
        self,
        test_dir: Path,
        config: Dict[str, Any],
        negative=False,
        exclude_types=None,
    ):
        """Initialise the selection parameters for an operator.

        test_dir: the directory where the tests for all operators can be found
        config: a dictionary with:
            "params" - a dictionary with mappings of parameter names to the values
                to select (a sub-set of expected values for instance)
            "permutes" - a list of parameter names to be permuted
            "preselected" - a list of dictionaries containing parameter names and
                pre-chosen values
            "sparsity" - a dictionary of parameter names with a sparsity value
            "errorifs" - list of ERRORIF case names to be selected (negative test)
        negative: bool indicating if negative testing is being selected (ERRORIF tests)
        exclude_types: optional list of "type" parameter values; tests with any
            of these types are ignored

        The given config dictionary is not modified.

        EXAMPLE CONFIG:
            "params": {
                "output_type": [
                    "outi8",
                    "outb"
                ]
            },
            "permutes": [
                "shape",
                "type"
            ],
            "sparsity": {
                "pad": 15
            },
            "preselected": [
                {
                    "shape": "6",
                    "type": "i8",
                    "pad": "pad00"
                }
            ],
            "errorifs": [
                "InputZeroPointNotZero"
            ]
        """
        assert isinstance(
            self.name, str
        ), f"{self.__class__.__name__}: {self.name} is not a valid operator name"

        self.negative = negative
        self.wks_param_names = self.param_names.copy()

        # Work on shallow copies so that neither the negative test set up nor
        # the default filling below writes back into the caller's config
        params = dict(config.get("params", {}))
        permutes = list(config.get("permutes", []))
        preselected = list(config.get("preselected", []))
        if self.negative:
            # need to override positive set up - use "errorifs" config if set
            # add in errorif case before shape to support all ops, including
            # different ops like COND_IF and CONVnD etc
            index = self.wks_param_names.index("shape")
            self.wks_param_names[index:index] = ["ERRORIF", "case"]
            params = {x: [] for x in self.wks_param_names}
            params["case"] = list(config.get("errorifs", []))
            permutes = []
            preselected = []

        self.params = params
        self.permutes = permutes
        self.sparsity = dict(config.get("sparsity", {}))
        self.preselected = preselected
        self.non_permutes = [x for x in self.wks_param_names if x not in self.permutes]
        logger.info(f"{self.name}: permutes={self.permutes}")
        logger.info(f"{self.name}: non_permutes={self.non_permutes}")

        if exclude_types is None:
            exclude_types = []
        self.test_paths = [
            p
            for p in self.get_test_paths(test_dir, self.negative)
            # exclusion of types if requested
            if self.path_params(p)["type"] not in exclude_types
        ]
        if not self.test_paths:
            logger.error(f"no tests found for {self.name} in {test_dir}")
        logger.debug(f"{self.name}: paths={self.test_paths}")

        # get default parameter values for any not given (or given empty) in
        # the config
        default_params = self.get_default_params()
        for param, values in default_params.items():
            if not self.params.get(param):
                self.params[param] = values
        for param in self.wks_param_names:
            logger.info(f"{self.name}: params[{param}]={self.params[param]}")

    @staticmethod
    def _get_test_paths(test_dir: Path, base_dir_glob, path_glob, negative):
        """Generate test paths for operators using operator specifics.

        Yields, in sorted order, the paths matching path_glob inside each
        directory of test_dir matching base_dir_glob. Paths containing
        "ERRORIF" are yielded only when negative is set, and only those.
        """
        for base_dir in sorted(test_dir.glob(base_dir_glob)):
            for path in sorted(base_dir.glob(path_glob)):
                if ("ERRORIF" in str(path)) == bool(negative):
                    yield path

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths for this operator."""
        yield from Operator._get_test_paths(test_dir, f"{cls.name}*", "*", negative)

    def path_params(self, path):
        """Return a dictionary of params from the test path.

        The path name is split on "_"; the leading parts that make up the
        operator name are dropped and the remaining values are mapped, in
        order, to wks_param_names.
        """
        op_name_parts = self.name.split("_")
        values = path.name.split("_")[len(op_name_parts) :]
        assert len(values) == len(
            self.wks_param_names
        ), f"len({values}) == len({self.wks_param_names})"
        return dict(zip(self.wks_param_names, values))

    def get_default_params(self):
        """Get the default parameter values from the test names.

        Returns a dictionary mapping each working param name to the sorted
        list of distinct values found across all test paths.
        """
        params = {param: set() for param in self.wks_param_names}
        for path in self.test_paths:
            path_params = self.path_params(path)
            for k in params:
                params[k].add(path_params[k])
        return {param: sorted(values) for param, values in params.items()}

    def select_tests(self):  # noqa: C901 (function too complex)
        """Generate the paths to the selected tests for this operator.

        Selection happens in two passes:
        1. tests whose params exactly match a permuted or preselected
           combination are selected;
        2. remaining tests are selected if they use a param value not yet
           covered, skipping some where a sparsity value is set for the param.
        Any requested combinations or values that were not found are logged.
        """
        if not self.test_paths:
            # Exit early when nothing to select from
            return

        # the test paths that have not been selected yet
        unused_paths = set(self.test_paths)

        # a list of dictionaries of unused preselected parameter combinations
        unused_preselected = list(self.preselected)
        logger.debug(f"preselected: {unused_preselected}")

        # a list of dictionaries of unused permuted parameter combinations
        permutes = {k: self.params[k] for k in self.permutes}
        others = {k: self.params[k] for k in self.non_permutes}
        unused_permuted = list(expand_params(permutes, others))
        logger.debug(f"permuted: {unused_permuted}")

        # a dictionary of sets of unused parameter values
        if self.negative:
            # We only care about selecting a test for each errorif case
            unused_values = {k: set() for k in self.params}
            unused_values["case"] = set(self.params["case"])
        else:
            unused_values = {k: set(v) for k, v in self.params.items()}

        # select tests matching permuted, or preselected, parameter combinations
        for path in self.test_paths:
            path_params = self.path_params(path)
            if path_params in unused_permuted or path_params in unused_preselected:
                unused_paths.remove(path)
                if path_params in unused_preselected:
                    unused_preselected.remove(path_params)
                if path_params in unused_permuted:
                    unused_permuted.remove(path_params)
                    if self.negative:
                        # remove any other errorif cases, so we only match one
                        for p in list(unused_permuted):
                            if p["case"] == path_params["case"]:
                                unused_permuted.remove(p)
                # remove the param values used by this path
                for k in path_params:
                    unused_values[k].discard(path_params[k])
                logger.debug(f"FOUND: {path.name}")
                yield path

        # search for tests that match any unused parameter values
        for n, path in enumerate(sorted(unused_paths)):
            path_params = self.path_params(path)
            # select paths with unused param values
            # skipping some, if sparsity is set for the param
            for k in path_params:
                if path_params[k] in unused_values[k] and (
                    k not in self.sparsity or n % self.sparsity[k] == 0
                ):
                    # remove the param values used by this path
                    for p in path_params:
                        unused_values[p].discard(path_params[p])
                    logger.debug(f"FOUND: {path.name}")
                    yield path
                    break

        # report any preselected combinations that were not found
        for params in unused_preselected:
            logger.warning(f"MISSING preselected: {params}")
        # report any permuted combinations that were not found
        for params in unused_permuted:
            logger.debug(f"MISSING permutation: {params}")
        # report any param values that were not found
        for k, values in unused_values.items():
            if values:
                if k not in self.sparsity:
                    logger.warning(f"MISSING {len(values)} values for {k}: {values}")
                else:
                    logger.info(
                        f"Skipped {len(values)} values for {k} due to sparsity setting"
                    )
                    logger.debug(f"Values skipped: {values}")
336
337
class AbsOperator(Operator):
    """Selects the ABS operator tests."""

    name = "abs"
342
343
class ArithmeticRightShiftOperator(Operator):
    """Selects the ARITHMETIC_RIGHT_SHIFT operator tests."""

    name = "arithmetic_right_shift"
    # test name fields, in order
    param_names = ["shape", "type", "rounding"]
349
350
class AddOperator(Operator):
    """Selects the ADD operator tests."""

    name = "add"
355
356
class ArgmaxOperator(Operator):
    """Selects the ARGMAX operator tests."""

    name = "argmax"
    # test name fields, in order
    param_names = ["shape", "type", "axis"]
362
363
class AvgPool2dOperator(Operator):
    """Selects the AVG_POOL2D operator tests."""

    name = "avg_pool2d"
    # test name fields, in order
    param_names = [
        "shape",
        "type",
        "accum_type",
        "stride",
        "kernel",
        "pad",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100369
370
class BitwiseAndOperator(Operator):
    """Selects the BITWISE_AND operator tests."""

    name = "bitwise_and"
375
376
class BitwiseNotOperator(Operator):
    """Selects the BITWISE_NOT operator tests."""

    name = "bitwise_not"
381
382
class BitwiseOrOperator(Operator):
    """Selects the BITWISE_OR operator tests."""

    name = "bitwise_or"
387
388
class BitwiseXorOperator(Operator):
    """Selects the BITWISE_XOR operator tests."""

    name = "bitwise_xor"
393
394
class CastOperator(Operator):
    """Selects the CAST operator tests."""

    name = "cast"
    # test name fields, in order
    param_names = ["shape", "type", "output_type"]
400
401
class CeilOperator(Operator):
    """Selects the CEIL operator tests."""

    name = "ceil"
406
407
class ClampOperator(Operator):
    """Selects the CLAMP operator tests."""

    name = "clamp"
412
413
class CLZOperator(Operator):
    """Selects the CLZ operator tests."""

    name = "clz"
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100418
419
class ConcatOperator(Operator):
    """Selects the CONCAT operator tests."""

    name = "concat"
    # test name fields, in order
    param_names = ["shape", "type", "axis"]
425
426
class CondIfOperator(Operator):
    """Selects the COND_IF operator tests."""

    name = "cond_if"
    # test name fields, in order (the variant comes first for this operator)
    param_names = ["variant", "shape", "type", "cond"]
432
433
class ConstOperator(Operator):
    """Selects the CONST operator tests."""

    name = "const"
438
439
class Conv2dOperator(Operator):
    """Selects the CONV2D operator tests."""

    name = "conv2d"
    # test name fields, in order (the kernel comes first for this operator)
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100445
446
class Conv3dOperator(Operator):
    """Selects the CONV3D operator tests."""

    name = "conv3d"
    # test name fields, in order (the kernel comes first for this operator)
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100452
453
class DepthwiseConv2dOperator(Operator):
    """Selects the DEPTHWISE_CONV2D operator tests."""

    name = "depthwise_conv2d"
    # test name fields, in order (the kernel comes first for this operator)
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100459
460
class EqualOperator(Operator):
    """Selects the EQUAL operator tests."""

    name = "equal"
465
466
class FloorOperator(Operator):
    """Selects the FLOOR operator tests."""

    name = "floor"
471
472
class FullyConnectedOperator(Operator):
    """Selects the FULLY_CONNECTED operator tests."""

    name = "fully_connected"
    # test name fields, in order
    param_names = ["shape", "type", "accum_type"]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100478
479
class GatherOperator(Operator):
    """Selects the GATHER operator tests."""

    name = "gather"
484
485
class GreaterOperator(Operator):
    """Selects the GREATER operator tests."""

    name = "greater"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Yield this operator's test paths, matching the directory name exactly."""
        # No trailing wildcard here: the default "greater*" glob would also
        # pick up directories starting with greater_equal.
        yield from Operator._get_test_paths(test_dir, cls.name, "*", negative)
495
496
class GreaterEqualOperator(Operator):
    """Selects the GREATER_EQUAL operator tests."""

    name = "greater_equal"
501
502
class IdentityOperator(Operator):
    """Selects the IDENTITY operator tests."""

    name = "identity"
507
508
class IntDivOperator(Operator):
    """Test selector for the INTDIV operator."""

    name = "intdiv"
513
514
class LogicalAndOperator(Operator):
    """Selects the LOGICAL_AND operator tests."""

    name = "logical_and"
519
520
class LogicalLeftShiftOperator(Operator):
    """Selects the LOGICAL_LEFT_SHIFT operator tests."""

    name = "logical_left_shift"
525
526
class LogicalNotOperator(Operator):
    """Selects the LOGICAL_NOT operator tests."""

    name = "logical_not"
531
532
class LogicalOrOperator(Operator):
    """Selects the LOGICAL_OR operator tests."""

    name = "logical_or"
537
538
class LogicalRightShiftOperator(Operator):
    """Selects the LOGICAL_RIGHT_SHIFT operator tests."""

    name = "logical_right_shift"
543
544
class LogicalXorOperator(Operator):
    """Selects the LOGICAL_XOR operator tests."""

    name = "logical_xor"
549
550
class MatmulOperator(Operator):
    """Selects the MATMUL operator tests."""

    name = "matmul"
    # test name fields, in order
    param_names = ["shape", "type", "accum_type"]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100556
557
class MaximumOperator(Operator):
    """Selects the MAXIMUM operator tests."""

    name = "maximum"
562
563
class MaxPool2dOperator(Operator):
    """Selects the MAX_POOL2D operator tests."""

    name = "max_pool2d"
    # test name fields, in order
    param_names = [
        "shape",
        "type",
        "stride",
        "kernel",
        "pad",
    ]
569
570
class MinimumOperator(Operator):
    """Selects the MINIMUM operator tests."""

    name = "minimum"
575
576
class MulOperator(Operator):
    """Selects the MUL operator tests."""

    name = "mul"
    # test name fields, in order
    param_names = ["shape", "type", "perm", "shift"]
582
583
class NegateOperator(Operator):
    """Selects the NEGATE operator tests."""

    name = "negate"
588
589
class PadOperator(Operator):
    """Selects the PAD operator tests."""

    name = "pad"
    # test name fields, in order
    param_names = ["shape", "type", "pad"]
595
596
class PowOperator(Operator):
    """Selects the POW operator tests."""

    name = "pow"
601
602
class ReduceAllOperator(Operator):
    """Selects the REDUCE_ALL operator tests."""

    name = "reduce_all"
    # test name fields, in order
    param_names = ["shape", "type", "axis"]
608
609
class ReduceAnyOperator(Operator):
    """Selects the REDUCE_ANY operator tests."""

    name = "reduce_any"
    # test name fields, in order
    param_names = ["shape", "type", "axis"]
615
616
class ReduceMaxOperator(Operator):
    """Selects the REDUCE_MAX operator tests."""

    name = "reduce_max"
    # test name fields, in order
    param_names = ["shape", "type", "axis"]
622
623
class ReduceMinOperator(Operator):
    """Selects the REDUCE_MIN operator tests."""

    name = "reduce_min"
    # test name fields, in order
    param_names = ["shape", "type", "axis"]
629
630
class ReduceSumOperator(Operator):
    """Selects the REDUCE_SUM operator tests."""

    name = "reduce_sum"
    # test name fields, in order
    param_names = ["shape", "type", "axis"]
636
637
class RescaleOperator(Operator):
    """Selects the RESCALE operator tests."""

    name = "rescale"
    # test name fields, in order
    param_names = [
        "shape",
        "type",
        "output_type",
        "scale",
        "double_round",
        "per_channel",
    ]
650
651
class ReshapeOperator(Operator):
    """Selects the RESHAPE operator tests."""

    name = "reshape"
    # test name fields, in order
    param_names = ["shape", "type", "perm", "rank"]
657
658
class ResizeOperator(Operator):
    """Selects the RESIZE operator tests."""

    name = "resize"
    # test name fields, in order
    param_names = [
        "shape",
        "type",
        "mode",
        "output_type",
        "scale",
        "offset",
        "border",
    ]
672
673
class ReverseOperator(Operator):
    """Selects the REVERSE operator tests."""

    name = "reverse"
    # test name fields, in order
    param_names = ["shape", "type", "axis"]
679
680
class ScatterOperator(Operator):
    """Selects the SCATTER operator tests."""

    name = "scatter"
685
686
class SelectOperator(Operator):
    """Selects the SELECT operator tests."""

    name = "select"
691
692
class SigmoidOperator(Operator):
    """Selects the SIGMOID operator tests."""

    name = "sigmoid"
697
698
class SliceOperator(Operator):
    """Selects the SLICE operator tests."""

    name = "slice"
    # test name fields, in order
    param_names = ["shape", "type", "perm"]
704
705
class SubOperator(Operator):
    """Selects the SUB operator tests."""

    name = "sub"
710
711
class TableOperator(Operator):
    """Selects the TABLE operator tests."""

    name = "table"
716
717
class TanhOperator(Operator):
    """Selects the TANH operator tests."""

    name = "tanh"
722
723
class TileOperator(Operator):
    """Selects the TILE operator tests."""

    name = "tile"
    # test name fields, in order
    param_names = ["shape", "type", "perm"]
729
730
class TransposeOperator(Operator):
    """Selects the TRANSPOSE operator tests."""

    name = "transpose"
    # test name fields, in order
    param_names = ["shape", "type", "perm"]

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Yield this operator's test paths, matching the directory name exactly."""
        # No trailing wildcard here: the default "transpose*" glob would also
        # pick up directories starting with transpose_conv2d.
        yield from Operator._get_test_paths(test_dir, cls.name, "*", negative)
741
742
class TransposeConv2dOperator(Operator):
    """Selects the TRANSPOSE_CONV2D operator tests."""

    name = "transpose_conv2d"
    # test name fields, in order (the kernel comes first for this operator)
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "out_shape",
    ]

    def path_params(self, path):
        """Return the params parsed from the test path, with out_shape blanked."""
        # Every test case has its own out_shape, so keeping the real value
        # would stop any path matching a permuted/preselected combination.
        return {**super().path_params(path), "out_shape": ""}
763
764
class WhileLoopOperator(Operator):
    """Selects the WHILE_LOOP operator tests."""

    name = "while_loop"
    # test name fields, in order
    param_names = ["shape", "type", "cond"]
770
771
def parse_args():
    """Parse the command line arguments.

    Returns the parsed argparse namespace. Exits via the parser's usage error
    if any operator named on the command line is not a registered Operator,
    rather than silently selecting nothing for it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-dir",
        default=Path.cwd(),
        type=Path,
        help=(
            "The directory where test subdirectories for all operators can be found"
            " (default: current working directory)"
        ),
    )
    parser.add_argument(
        "--config",
        default=Path(__file__).with_suffix(".json"),
        type=Path,
        help="A JSON file defining the parameters to use for each operator",
    )
    parser.add_argument(
        "--full-path", action="store_true", help="output the full path for each test"
    )
    parser.add_argument(
        "-v",
        dest="verbosity",
        action="count",
        default=0,
        help="Verbosity (can be used multiple times for more details)",
    )
    parser.add_argument(
        "operators",
        type=str,
        nargs="*",
        help=(
            "Select tests for the specified operator(s)"
            " - all operators are assumed if none are specified"
            f" - choose from: {list(Operator.registry)}"
        ),
    )
    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative"],
        default="positive",
        type=str,
        help="type of tests selected, positive or negative",
    )
    args = parser.parse_args()

    # A mistyped operator name would otherwise just produce no output
    unknown = [op for op in args.operators if op not in Operator.registry]
    if unknown:
        parser.error(f"unknown operator(s): {unknown}")
    return args
819
820
def main():
    """Select tests for the requested operators and print them.

    Prints one selected test per line (its name, or its resolved path when
    --full-path is given). Returns 0 on success, or 2 if the config file
    cannot be read, is not valid JSON, or does not hold a JSON object.
    """
    args = parse_args()

    loglevels = (logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG)
    logger.setLevel(loglevels[min(args.verbosity, len(loglevels) - 1)])
    logger.info(f"{__file__}: args: {args}")

    try:
        with open(args.config, "r", encoding="utf-8") as fd:
            config = json.load(fd)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError
        logger.error(f"Config file error: {e}")
        return 2
    if not isinstance(config, dict):
        logger.error(f"Config file error: {args.config} does not contain a JSON object")
        return 2

    negative = args.test_type == "negative"
    for op_name, op_class in Operator.registry.items():
        if args.operators and op_name not in args.operators:
            continue
        op_params = config.get(op_name, {})
        op = op_class(args.test_dir, op_params, negative, exclude_types=["float"])
        for test_path in op.select_tests():
            print(test_path.resolve() if args.full_path else test_path.name)

    return 0
847
848
if __name__ == "__main__":
    # Equivalent to sys.exit(); unlike the site module's exit() helper it is
    # always available (e.g. when Python is run with -S).
    raise SystemExit(main())