blob: 8b60fbb5094dc0d6ea281b307a5a6b1593101fa5 [file] [log] [blame]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
3"""Select generated tests."""
4import argparse
5import itertools
6import json
7import logging
8from pathlib import Path
9from typing import Any
10from typing import Dict
11from typing import List
12
# Install a default root handler so messages from this script are printed;
# main() adjusts the threshold of the module logger below from -v flags
logging.basicConfig()
logger = logging.getLogger(name="test_select")
15
16
def expand_params(permutes: Dict[str, List[Any]], others: Dict[str, List[Any]]):
    """Generate permuted combinations of a dictionary of values and combine with others.

    permutes: a dictionary with sequences of values to be fully permuted
    others: a dictionary with sequences of values not fully permuted, but all used

    This yields dictionaries with one value from each of the items in permutes,
    combined with one value from each of the items in others.

    The number of dictionaries generated is the larger of the number of
    permutations and the length of the longest sequence in others; shorter
    sequences are cycled. If both dictionaries are empty a single empty
    dictionary is generated. If any item has an empty sequence of values no
    combination is possible, so nothing is generated.

    Example 1:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 5, "d": False},
        ]

    Example 2:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7, 8, 9], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 8, "d": False},
            {"a": 1, "b": 3, "c": 9, "d": True},
        ]

    Raises:
        ValueError if any item is in both permutes and others
    """
    for k in permutes:
        if k in others:
            raise ValueError(f"item conflict: {k}")

    # An item with no values cannot appear in any combination. Stop here rather
    # than calling next() on an empty cycle below, which would surface as a
    # RuntimeError from this generator.
    if any(not v for v in itertools.chain(permutes.values(), others.values())):
        return

    p_keys = list(permutes)
    # product() of no sequences is a single empty tuple, so an empty permutes
    # dictionary still contributes exactly one (empty) combination
    p_combinations = list(itertools.product(*permutes.values()))

    o_keys = list(others)
    # a separate cyclic generator for each list of values in others
    o_generators = [itertools.cycle(v) for v in others.values()]

    # The number of params dictionaries generated will be the maximum size
    # of the permuted values and the non-permuted values from others
    max_items = max([len(p_combinations)] + [len(v) for v in others.values()])

    p_generator = itertools.cycle(p_combinations)
    for _ in range(max_items):
        # each permuted combination holds one value per permutes key, in the
        # same order as the keys were originally given
        params = dict(zip(p_keys, next(p_generator)))
        for key, generator in zip(o_keys, o_generators):
            params[key] = next(generator)
        yield params
105
106
class Operator:
    """Base class for operator specific selection properties.

    Subclasses set ``name`` (the operator name that prefixes its test names)
    and, when their tests carry more than the default shape/type parameters,
    ``param_names``. Defining a subclass registers it in ``Operator.registry``
    under its ``name``.
    """

    # A registry of all Operator subclasses, indexed by the operator name
    registry = {}

    def __init_subclass__(cls, **kwargs):
        """Subclass initialiser to register all Operator classes."""
        super().__init_subclass__(**kwargs)
        cls.registry[cls.name] = cls

    # Derived classes must override the operator name
    name = None
    # Operators with additional parameters must override the param_names
    # NB: the order must match the order the values appear in the test names
    param_names = ["shape", "type"]

    # Working set of param_names - updated for negative tests
    wks_param_names = None

    def __init__(
        self,
        test_dir: Path,
        config: Dict[str, Dict[str, List[Any]]],
        negative=False,
        exclude_types=None,
    ):
        """Initialise the selection parameters for an operator.

        test_dir: the directory where the tests for all operators can be found
        config: a dictionary with:
            "params" - mappings of parameter names to the values to select
            "permutes" - a list of parameter names to be permuted
            "preselected" - a list of parameter dictionaries to always select
            "sparsity" - mappings of parameter names to a selection interval
            "errorifs" - list of ERRORIF case names to be selected (negative test)
            The caller's dictionary is not modified.
        negative: bool indicating if negative testing is being selected (ERRORIF tests)
        exclude_types: optional list of "type" values whose tests are ignored
        """
        assert isinstance(
            self.name, str
        ), f"{self.__class__.__name__}: {self.name} is not a valid operator name"

        # Work on a copy so the caller's config is never modified (previously
        # negative set up and default params were written back into it)
        config = dict(config)
        self.negative = negative
        self.wks_param_names = self.param_names.copy()
        if self.negative:
            # need to override positive set up - use "errorifs" config if set
            # add in errorif case before shape to support all ops, including
            # different ops like COND_IF and CONVnD etc
            index = self.wks_param_names.index("shape")
            self.wks_param_names[index:index] = ["ERRORIF", "case"]
            config["params"] = {x: [] for x in self.wks_param_names}
            config["params"]["case"] = config.get("errorifs", [])
            config["permutes"] = []
            config["preselected"] = []

        # copy params, as missing entries are filled with defaults below
        self.params = dict(config.get("params", {}))
        self.permutes = config.get("permutes", [])
        self.sparsity = config.get("sparsity", {})
        self.preselected = config.get("preselected", [])
        self.non_permutes = [x for x in self.wks_param_names if x not in self.permutes]
        logger.info(f"{self.name}: permutes={self.permutes}")
        logger.info(f"{self.name}: non_permutes={self.non_permutes}")

        if exclude_types is None:
            exclude_types = []
        self.test_paths = [
            p
            for p in self.get_test_paths(test_dir, self.negative)
            # exclusion of types if requested
            if self.path_params(p)["type"] not in exclude_types
        ]
        if not self.test_paths:
            logger.error(f"no tests found for {self.name} in {test_dir}")
        logger.debug(f"{self.name}: paths={self.test_paths}")

        # get default parameter values for any not given (or empty) in the config
        for param, values in self.get_default_params().items():
            if not self.params.get(param):
                self.params[param] = values
        for param in self.wks_param_names:
            logger.info(f"{self.name}: params[{param}]={self.params[param]}")

    @staticmethod
    def _get_test_paths(test_dir: Path, base_dir_glob, path_glob, negative):
        """Generate test paths for operators using operator specifics.

        Paths are generated in sorted order. When negative is set only paths
        containing "ERRORIF" are generated, otherwise only those without it.
        """
        for base_dir in sorted(test_dir.glob(base_dir_glob)):
            for path in sorted(base_dir.glob(path_glob)):
                if ("ERRORIF" in str(path)) == bool(negative):
                    yield path

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths for this operator."""
        yield from Operator._get_test_paths(test_dir, f"{cls.name}*", "*", negative)

    def path_params(self, path):
        """Return a dictionary of params from the test path.

        The test name is split on "_"; the leading parts forming the operator
        name are skipped and the rest are assigned, in order, to the working
        parameter names.
        """
        op_name_parts = self.name.split("_")
        values = path.name.split("_")[len(op_name_parts) :]
        assert len(values) == len(
            self.wks_param_names
        ), f"len({values}) == len({self.wks_param_names})"
        return dict(zip(self.wks_param_names, values))

    def get_default_params(self):
        """Get the default parameter values from the test names.

        Returns a dictionary mapping each working parameter name to the sorted
        list of distinct values found in self.test_paths.
        """
        params = {param: set() for param in self.wks_param_names}
        for path in self.test_paths:
            path_params = self.path_params(path)
            for k, values in params.items():
                values.add(path_params[k])
        return {param: sorted(values) for param, values in params.items()}

    def select_tests(self):  # noqa: C901 (function too complex)
        """Generate the paths to the selected tests for this operator.

        First pass: tests whose params match a permuted or preselected
        combination. Second pass: remaining tests using a parameter value not
        yet covered (subject to any sparsity setting). Anything that could not
        be matched is logged once the selection is exhausted.
        """
        if not self.test_paths:
            # Exit early when nothing to select from
            return

        # the test paths that have not been selected yet
        unused_paths = set(self.test_paths)

        # a list of dictionaries of unused preselected parameter combinations
        unused_preselected = list(self.preselected)
        logger.debug(f"preselected: {unused_preselected}")

        # a list of dictionaries of unused permuted parameter combinations
        permutes = {k: self.params[k] for k in self.permutes}
        others = {k: self.params[k] for k in self.non_permutes}
        unused_permuted = list(expand_params(permutes, others))
        logger.debug(f"permuted: {unused_permuted}")

        # a dictionary of sets of unused parameter values
        if self.negative:
            # We only care about selecting a test for each errorif case
            unused_values = {k: set() for k in self.params}
            unused_values["case"] = set(self.params["case"])
        else:
            unused_values = {k: set(v) for k, v in self.params.items()}

        # select tests matching permuted, or preselected, parameter combinations
        for path in self.test_paths:
            path_params = self.path_params(path)
            if path_params in unused_permuted or path_params in unused_preselected:
                unused_paths.remove(path)
                if path_params in unused_preselected:
                    unused_preselected.remove(path_params)
                if path_params in unused_permuted:
                    unused_permuted.remove(path_params)
                    if self.negative:
                        # remove any other errorif cases, so we only match one
                        # (single filtering pass instead of repeated list.remove)
                        unused_permuted = [
                            p
                            for p in unused_permuted
                            if p["case"] != path_params["case"]
                        ]
                # remove the param values used by this path
                for k in path_params:
                    unused_values[k].discard(path_params[k])
                logger.debug(f"FOUND: {path.name}")
                yield path

        # search for tests that match any unused parameter values
        for n, path in enumerate(sorted(unused_paths)):
            path_params = self.path_params(path)
            # select paths with unused param values
            # skipping some, if sparsity is set for the param
            for k in path_params:
                if path_params[k] in unused_values[k] and (
                    k not in self.sparsity or n % self.sparsity[k] == 0
                ):
                    # remove the param values used by this path
                    for p in path_params:
                        unused_values[p].discard(path_params[p])
                    logger.debug(f"FOUND: {path.name}")
                    yield path
                    break

        self._report_unselected(unused_preselected, unused_permuted, unused_values)

    def _report_unselected(self, unused_preselected, unused_permuted, unused_values):
        """Log the combinations and parameter values no selected test covered."""
        # report any preselected combinations that were not found
        for params in unused_preselected:
            logger.warning(f"MISSING preselected: {params}")
        # report any permuted combinations that were not found
        for params in unused_permuted:
            logger.debug(f"MISSING permutation: {params}")
        # report any param values that were not found
        for k, values in unused_values.items():
            if not values:
                continue
            if k not in self.sparsity:
                logger.warning(f"MISSING {len(values)} values for {k}: {values}")
            else:
                logger.info(
                    f"Skipped {len(values)} values for {k} due to sparsity setting"
                )
                logger.debug(f"Values skipped: {values}")
307
308
class AbsOperator(Operator):
    """ABS operator selection properties."""

    name = "abs"


class ArithmeticRightShiftOperator(Operator):
    """ARITHMETIC_RIGHT_SHIFT operator selection properties."""

    name = "arithmetic_right_shift"
    param_names = ["shape", "type", "rounding"]


class AddOperator(Operator):
    """ADD operator selection properties."""

    name = "add"


class ArgmaxOperator(Operator):
    """ARGMAX operator selection properties."""

    name = "argmax"
    param_names = ["shape", "type", "axis"]


class AvgPool2dOperator(Operator):
    """AVG_POOL2D operator selection properties."""

    name = "avg_pool2d"
    param_names = ["shape", "type", "stride", "kernel", "pad"]


class BitwiseAndOperator(Operator):
    """BITWISE_AND operator selection properties."""

    name = "bitwise_and"


class BitwiseNotOperator(Operator):
    """BITWISE_NOT operator selection properties."""

    name = "bitwise_not"


class BitwiseOrOperator(Operator):
    """BITWISE_OR operator selection properties."""

    name = "bitwise_or"


class BitwiseXorOperator(Operator):
    """BITWISE_XOR operator selection properties."""

    name = "bitwise_xor"


class CastOperator(Operator):
    """CAST operator selection properties."""

    name = "cast"
    param_names = ["shape", "type", "output_type"]


class ClampOperator(Operator):
    """CLAMP operator selection properties."""

    name = "clamp"


class CLZOperator(Operator):
    """CLZ operator selection properties (default shape/type parameters)."""

    name = "clz"


class ConcatOperator(Operator):
    """CONCAT operator selection properties."""

    name = "concat"
    param_names = ["shape", "type", "axis"]


class CondIfOperator(Operator):
    """COND_IF operator selection properties."""

    name = "cond_if"
    param_names = ["variant", "shape", "type", "cond"]


class ConstOperator(Operator):
    """CONST operator selection properties."""

    name = "const"


class Conv2dOperator(Operator):
    """CONV2D operator selection properties."""

    name = "conv2d"
    param_names = ["kernel", "shape", "type", "stride", "pad", "dilation"]


class Conv3dOperator(Operator):
    """CONV3D operator selection properties."""

    name = "conv3d"
    param_names = ["kernel", "shape", "type", "stride", "pad", "dilation"]


class DepthwiseConv2dOperator(Operator):
    """DEPTHWISE_CONV2D operator selection properties."""

    name = "depthwise_conv2d"
    param_names = ["kernel", "shape", "type", "stride", "pad", "dilation"]


class EqualOperator(Operator):
    """EQUAL operator selection properties."""

    name = "equal"


class FullyConnectedOperator(Operator):
    """FULLY_CONNECTED operator selection properties."""

    name = "fully_connected"


class GatherOperator(Operator):
    """GATHER operator selection properties."""

    name = "gather"
443
444
class GreaterOperator(Operator):
    """GREATER operator selection properties."""

    name = "greater"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths from the base directory named exactly "greater".

        The default "greater*" directory glob would also match directories
        whose names merely start with "greater" (such as greater_equal), so an
        exact name is used instead.
        """
        exact_dir_glob = cls.name
        yield from Operator._get_test_paths(test_dir, exact_dir_glob, "*", negative)
454
455
class GreaterEqualOperator(Operator):
    """GREATER_EQUAL operator selection properties."""

    name = "greater_equal"


class IdentityOperator(Operator):
    """IDENTITY operator selection properties."""

    name = "identity"


class IntDivOperator(Operator):
    """INTDIV operator selection properties."""

    name = "intdiv"


class LogicalAndOperator(Operator):
    """LOGICAL_AND operator selection properties."""

    name = "logical_and"


class LogicalLeftShiftOperator(Operator):
    """LOGICAL_LEFT_SHIFT operator selection properties."""

    name = "logical_left_shift"


class LogicalNotOperator(Operator):
    """LOGICAL_NOT operator selection properties."""

    name = "logical_not"


class LogicalOrOperator(Operator):
    """LOGICAL_OR operator selection properties."""

    name = "logical_or"


class LogicalRightShiftOperator(Operator):
    """LOGICAL_RIGHT_SHIFT operator selection properties."""

    name = "logical_right_shift"


class LogicalXorOperator(Operator):
    """LOGICAL_XOR operator selection properties."""

    name = "logical_xor"


class MatmulOperator(Operator):
    """MATMUL operator selection properties."""

    name = "matmul"


class MaximumOperator(Operator):
    """MAXIMUM operator selection properties."""

    name = "maximum"


class MaxPool2dOperator(Operator):
    """MAX_POOL2D operator selection properties."""

    name = "max_pool2d"
    param_names = ["shape", "type", "stride", "kernel", "pad"]


class MinimumOperator(Operator):
    """MINIMUM operator selection properties."""

    name = "minimum"
533
534
class MulOperator(Operator):
    """MUL operator selection properties."""

    name = "mul"
    param_names = ["shape", "type", "perm", "shift"]


class NegateOperator(Operator):
    """NEGATE operator selection properties."""

    name = "negate"


class PadOperator(Operator):
    """PAD operator selection properties."""

    name = "pad"
    param_names = ["shape", "type", "pad"]


class ReduceAllOperator(Operator):
    """REDUCE_ALL operator selection properties."""

    name = "reduce_all"
    param_names = ["shape", "type", "axis"]


class ReduceAnyOperator(Operator):
    """REDUCE_ANY operator selection properties."""

    name = "reduce_any"
    param_names = ["shape", "type", "axis"]


class ReduceMaxOperator(Operator):
    """REDUCE_MAX operator selection properties."""

    name = "reduce_max"
    param_names = ["shape", "type", "axis"]


class ReduceMinOperator(Operator):
    """REDUCE_MIN operator selection properties."""

    name = "reduce_min"
    param_names = ["shape", "type", "axis"]


class ReduceSumOperator(Operator):
    """REDUCE_SUM operator selection properties."""

    name = "reduce_sum"
    param_names = ["shape", "type", "axis"]


class RescaleOperator(Operator):
    """RESCALE operator selection properties."""

    name = "rescale"
    param_names = [
        "shape",
        "type",
        "output_type",
        "scale",
        "double_round",
        "per_channel",
    ]


class ReshapeOperator(Operator):
    """RESHAPE operator selection properties."""

    name = "reshape"
    param_names = ["shape", "type", "perm", "rank"]


class ResizeOperator(Operator):
    """RESIZE operator selection properties."""

    name = "resize"
    param_names = [
        "shape",
        "type",
        "mode",
        "output_type",
        "scale",
        "offset",
        "border",
    ]


class ReverseOperator(Operator):
    """REVERSE operator selection properties."""

    name = "reverse"
    param_names = ["shape", "type", "axis"]


class ScatterOperator(Operator):
    """SCATTER operator selection properties."""

    name = "scatter"


class SelectOperator(Operator):
    """SELECT operator selection properties."""

    name = "select"


class SliceOperator(Operator):
    """SLICE operator selection properties."""

    name = "slice"
    param_names = ["shape", "type", "perm"]


class SubOperator(Operator):
    """SUB operator selection properties."""

    name = "sub"


class TableOperator(Operator):
    """TABLE operator selection properties."""

    name = "table"


class TileOperator(Operator):
    """TILE operator selection properties."""

    name = "tile"
    param_names = ["shape", "type", "perm"]
669
670
class TransposeOperator(Operator):
    """TRANSPOSE operator selection properties."""

    name = "transpose"
    param_names = ["shape", "type", "perm"]

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths from the base directory named exactly "transpose".

        The default "transpose*" directory glob would also match directories
        whose names merely start with "transpose" (such as transpose_conv2d),
        so an exact name is used instead.
        """
        exact_dir_glob = cls.name
        yield from Operator._get_test_paths(test_dir, exact_dir_glob, "*", negative)
681
682
class TransposeConv2dOperator(Operator):
    """TRANSPOSE_CONV2D operator selection properties."""

    name = "transpose_conv2d"
    param_names = ["kernel", "shape", "type", "stride", "pad", "out_shape"]

    def path_params(self, path):
        """Return the test path params with out_shape blanked out.

        Every test case has a different out_shape, so it is replaced with an
        empty string to stop it influencing the selection.
        """
        return {**super().path_params(path), "out_shape": ""}
695
696
class WhileLoopOperator(Operator):
    """WHILE_LOOP operator selection properties."""

    name = "while_loop"
    param_names = ["shape", "type", "cond"]
702
703
def parse_args():
    """Parse the command line arguments.

    Returns the argparse namespace with test_dir, config, full_path,
    verbosity, operators and test_type attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-dir",
        default=Path.cwd(),
        type=Path,
        help=(
            "The directory where test subdirectories for all operators can be found"
            " (default: current working directory)"
        ),
    )
    parser.add_argument(
        "--config",
        default=Path(__file__).with_suffix(".json"),
        type=Path,
        help="A JSON file defining the parameters to use for each operator",
    )
    parser.add_argument(
        "--full-path", action="store_true", help="output the full path for each test"
    )
    parser.add_argument(
        "-v",
        dest="verbosity",
        action="count",
        default=0,
        help="Verbosity (can be used multiple times for more details)",
    )
    parser.add_argument(
        "operators",
        type=str,
        nargs="*",
        # the previous text closed a parenthesis that was never opened
        help=(
            "Select tests for the specified operator(s)"
            " - all operators are assumed if none are specified"
            f" - choose from: {list(Operator.registry)}"
        ),
    )
    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative"],
        default="positive",
        type=str,
        help="type of tests selected, positive or negative",
    )
    return parser.parse_args()
751
752
def main():
    """Example test selection.

    Prints the name (or the full path with --full-path) of every selected test
    and returns the process exit code: 0 on success, 2 when the config file
    cannot be read or parsed.
    """
    args = parse_args()

    loglevels = (logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG)
    # Logger objects have no basicConfig() (that is a logging module function),
    # so the previous call raised AttributeError; set this logger's level instead
    logger.setLevel(loglevels[min(args.verbosity, len(loglevels) - 1)])
    logger.info(f"{__file__}: args: {args}")

    try:
        with open(args.config, "r", encoding="utf-8") as fd:
            config = json.load(fd)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError
        logger.error(f"Config file error: {e}")
        return 2

    # operator names that match nothing were previously ignored silently
    for requested in args.operators:
        if requested not in Operator.registry:
            logger.error(f"unknown operator: {requested}")

    negative = args.test_type == "negative"
    for op_name, op_class in Operator.registry.items():
        if not args.operators or op_name in args.operators:
            op = op_class(
                args.test_dir,
                config.get(op_name, {}),
                negative,
                exclude_types=["float"],
            )
            for test_path in op.select_tests():
                print(test_path.resolve() if args.full_path else test_path.name)

    return 0
779
780
781if __name__ == "__main__":
782 exit(main())