blob: 05f6db8a214af0662b3bfd3dde3775413752c69f [file] [log] [blame]
Jeremy Johnson35396f22023-01-04 17:05:25 +00001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01002# SPDX-License-Identifier: Apache-2.0
3"""Select generated tests."""
4import argparse
5import itertools
6import json
7import logging
8from pathlib import Path
9from typing import Any
10from typing import Dict
11from typing import List
12
# Named logger used throughout this script; root logging gets default config
logger = logging.getLogger("test_select")
logging.basicConfig()
15
16
def expand_params(permutes: Dict[str, List[Any]], others: Dict[str, List[Any]]):
    """Generate permuted combinations of a dictionary of values and combine with others.

    permutes: a dictionary with sequences of values to be fully permuted
    others: a dictionary with sequences of values not fully permuted, but all used

    This yields dictionaries with one value from each of the items in permutes,
    combined with one value from each of the items in others.

    The number of dictionaries yielded is the larger of the number of
    permutes combinations (1 when permutes is empty) and the length of the
    longest others sequence; shorter sequences are cycled. If any item has no
    values at all, no complete combination can be formed and nothing is
    yielded.

    Example 1:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 5, "d": False},
        ]

    Example 2:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7, 8, 9], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 8, "d": False},
            {"a": 1, "b": 3, "c": 9, "d": True},
        ]

    Raises:
        ValueError if any item is in both permutes and others
    """
    for k in permutes:
        if k in others:
            raise ValueError(f"item conflict: {k}")

    # An item without values cannot contribute to any combination. Bail out
    # explicitly; otherwise next() on an empty itertools.cycle would surface
    # as a confusing "generator raised StopIteration" RuntimeError.
    if any(
        len(values) == 0
        for values in itertools.chain(permutes.values(), others.values())
    ):
        return

    p_keys = list(permutes)
    # itertools.product() of no iterables yields a single empty tuple, so an
    # empty permutes dictionary counts as one (empty) combination - this lets
    # us yield a single, empty dictionary when others is also empty
    p_product_len = 1
    for values in permutes.values():
        p_product_len *= len(values)
    # cyclic generator over every combination of the permuted values; each
    # tuple holds one value per permutes key, in the original key order
    p_generator = itertools.cycle(itertools.product(*permutes.values()))

    # a separate cyclic generator for each of the others value sequences
    o_keys = list(others)
    o_generators = [itertools.cycle(values) for values in others.values()]

    # The number of params dictionaries generated will be the maximum size
    # of the permuted values and the non-permuted values from others
    max_items = max([p_product_len] + [len(values) for values in others.values()])

    for _ in range(max_items):
        params = dict(zip(p_keys, next(p_generator)))
        for key, generator in zip(o_keys, o_generators):
            params[key] = next(generator)
        yield params
105
106
class Operator:
    """Base class for operator specific selection properties.

    Subclasses set ``name`` (the prefix of their test names) and, when their
    test names carry extra values, ``param_names``. Every subclass is added to
    ``Operator.registry`` under its ``name`` when the class is defined.
    """

    # A registry of all Operator subclasses, indexed by the operator name
    registry = {}

    def __init_subclass__(cls, **kwargs):
        """Subclass initialiser to register all Operator classes."""
        super().__init_subclass__(**kwargs)
        cls.registry[cls.name] = cls

    # Derived classes must override the operator name
    name = None
    # Operators with additional parameters must override the param_names
    # NB: the order must match the order the values appear in the test names
    param_names = ["shape", "type"]

    # Working set of param_names - updated for negative tests
    wks_param_names = None

    def __init__(
        self,
        test_dir: Path,
        config: Dict[str, Dict[str, List[Any]]],
        negative=False,
        exclude_types=None,
    ):
        """Initialise the selection parameters for an operator.

        test_dir: the directory where the tests for all operators can be found
        config: a dictionary with:
            "params" - a dictionary with mappings of parameter names to the values
                to select (a sub-set of expected values for instance)
            "permutes" - a list of parameter names to be permuted
            "preselected" - a list of dictionaries containing parameter names and
                pre-chosen values
            "sparsity" - a dictionary of parameter names with a sparsity value
            "errorifs" - list of ERRORIF case names to be selected (negative test)
        negative: bool indicating if negative testing is being selected (ERRORIF tests)
        exclude_types: optional list of "type" parameter values; tests with one
            of these types are ignored

        The caller's config dictionary (including its "params" dictionary) is
        not modified.

        EXAMPLE CONFIG:
            "params": {
                "output_type": [
                    "outi8",
                    "outb"
                ]
            },
            "permutes": [
                "shape",
                "type"
            ],
            "sparsity": {
                "pad": 15
            },
            "preselected": [
                {
                    "shape": "6",
                    "type": "i8",
                    "pad": "pad00"
                }
            ],
            "errorifs": [
                "InputZeroPointNotZero"
            ]
        """
        assert isinstance(
            self.name, str
        ), f"{self.__class__.__name__}: {self.name} is not a valid operator name"

        self.negative = negative
        self.wks_param_names = self.param_names.copy()
        # shallow copy, so the negative-test overrides below stay local
        config = dict(config)
        if self.negative:
            # need to override positive set up - use "errorifs" config if set
            # add in errorif case before shape to support all ops, including
            # different ops like COND_IF and CONVnD etc
            index = self.wks_param_names.index("shape")
            self.wks_param_names[index:index] = ["ERRORIF", "case"]
            config["params"] = {x: [] for x in self.wks_param_names}
            config["params"]["case"] = config.get("errorifs", [])
            config["permutes"] = []
            config["preselected"] = []

        # copy "params": missing or empty entries are filled with defaults below
        self.params = dict(config.get("params", {}))
        self.permutes = config.get("permutes", [])
        self.sparsity = config.get("sparsity", {})
        self.preselected = config.get("preselected", [])
        self.non_permutes = [x for x in self.wks_param_names if x not in self.permutes]
        logger.info(f"{self.name}: permutes={self.permutes}")
        logger.info(f"{self.name}: non_permutes={self.non_permutes}")

        if exclude_types is None:
            exclude_types = []
        self.test_paths = [
            p
            for p in self.get_test_paths(test_dir, self.negative)
            # exclusion of types if requested
            if self.path_params(p)["type"] not in exclude_types
        ]
        if not self.test_paths:
            logger.error(f"no tests found for {self.name} in {test_dir}")
        logger.debug(f"{self.name}: paths={self.test_paths}")

        # get default parameter values for any not given in the config
        default_params = self.get_default_params()
        for param in default_params:
            if param not in self.params or not self.params[param]:
                self.params[param] = default_params[param]
        for param in self.wks_param_names:
            logger.info(f"{self.name}: params[{param}]={self.params[param]}")

    @staticmethod
    def _get_test_paths(test_dir: Path, base_dir_glob, path_glob, negative):
        """Generate test paths for operators using operator specifics.

        Entries of test_dir matching base_dir_glob are searched (in sorted
        order) for entries matching path_glob. Paths containing "ERRORIF" are
        yielded only for negative selection; all other paths only for
        positive selection.
        """
        for base_dir in sorted(test_dir.glob(base_dir_glob)):
            for path in sorted(base_dir.glob(path_glob)):
                if (not negative and "ERRORIF" not in str(path)) or (
                    negative and "ERRORIF" in str(path)
                ):
                    yield path

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths for this operator."""
        yield from Operator._get_test_paths(test_dir, f"{cls.name}*", "*", negative)

    def path_params(self, path):
        """Return a dictionary of params from the test path.

        The test name is split on "_"; the leading parts that form the
        operator name are skipped and the remaining parts are assigned, in
        order, to the working parameter names.
        """
        params = {}
        op_name_parts = self.name.split("_")
        values = path.name.split("_")[len(op_name_parts) :]
        assert len(values) == len(
            self.wks_param_names
        ), f"len({values}) == len({self.wks_param_names})"
        for i, param in enumerate(self.wks_param_names):
            params[param] = values[i]
        return params

    def get_default_params(self):
        """Get the default parameter values from the test names.

        Returns:
            dict mapping each working parameter name to the sorted list of
            distinct values seen across self.test_paths
        """
        params = {param: set() for param in self.wks_param_names}
        for path in self.test_paths:
            path_params = self.path_params(path)
            for k in params:
                params[k].add(path_params[k])
        for param in params:
            params[param] = sorted(params[param])
        return params

    def select_tests(self):  # noqa: C901 (function too complex)
        """Generate the paths to the selected tests for this operator.

        1. Tests exactly matching a preselected or permuted parameter
           combination are yielded (for negative tests, at most one permuted
           match per ERRORIF case).
        2. Remaining tests that use a not-yet-covered parameter value are
           yielded, skipping some when "sparsity" is set for that parameter.
        3. Combinations and values that were never found are logged.
        """
        if not self.test_paths:
            # Exit early when nothing to select from
            return

        # the test paths that have not been selected yet
        unused_paths = set(self.test_paths)

        # a list of dictionaries of unused preselected parameter combinations
        unused_preselected = list(self.preselected)
        logger.debug(f"preselected: {unused_preselected}")

        # a list of dictionaries of unused permuted parameter combinations
        permutes = {k: self.params[k] for k in self.permutes}
        others = {k: self.params[k] for k in self.non_permutes}
        unused_permuted = list(expand_params(permutes, others))
        logger.debug(f"permuted: {unused_permuted}")

        # a dictionary of sets of unused parameter values
        if self.negative:
            # We only care about selecting a test for each errorif case
            unused_values = {k: set() for k in self.params}
            unused_values["case"] = set(self.params["case"])
        else:
            unused_values = {k: set(v) for k, v in self.params.items()}

        # select tests matching permuted, or preselected, parameter combinations
        for path in self.test_paths:
            path_params = self.path_params(path)
            if path_params in unused_permuted or path_params in unused_preselected:
                unused_paths.remove(path)
                if path_params in unused_preselected:
                    unused_preselected.remove(path_params)
                if path_params in unused_permuted:
                    unused_permuted.remove(path_params)
                    if self.negative:
                        # remove any other errorif cases, so we only match one
                        for p in list(unused_permuted):
                            if p["case"] == path_params["case"]:
                                unused_permuted.remove(p)
                # remove the param values used by this path
                for k in path_params:
                    unused_values[k].discard(path_params[k])
                logger.debug(f"FOUND: {path.name}")
                yield path

        # search for tests that match any unused parameter values
        for n, path in enumerate(sorted(unused_paths)):
            path_params = self.path_params(path)
            # select paths with unused param values
            # skipping some, if sparsity is set for the param
            for k in path_params:
                if path_params[k] in unused_values[k] and (
                    k not in self.sparsity or n % self.sparsity[k] == 0
                ):
                    # remove the param values used by this path
                    for p in path_params:
                        unused_values[p].discard(path_params[p])
                    logger.debug(f"FOUND: {path.name}")
                    yield path
                    break

        # report any preselected combinations that were not found
        for params in unused_preselected:
            logger.warning(f"MISSING preselected: {params}")
        # report any permuted combinations that were not found
        for params in unused_permuted:
            logger.debug(f"MISSING permutation: {params}")
        # report any param values that were not found
        for k, values in unused_values.items():
            if values:
                if k not in self.sparsity:
                    logger.warning(f"MISSING {len(values)} values for {k}: {values}")
                else:
                    logger.info(
                        f"Skipped {len(values)} values for {k} due to sparsity setting"
                    )
                    logger.debug(f"Values skipped: {values}")
336
337
class AbsOperator(Operator):
    """Selection rules for ABS operator tests."""

    name = "abs"
342
343
class ArithmeticRightShiftOperator(Operator):
    """Selection rules for ARITHMETIC_RIGHT_SHIFT operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "rounding"]
    name = "arithmetic_right_shift"
349
350
class AddOperator(Operator):
    """Selection rules for ADD operator tests."""

    name = "add"
355
356
class ArgmaxOperator(Operator):
    """Selection rules for ARGMAX operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "axis"]
    name = "argmax"
362
363
class AvgPool2dOperator(Operator):
    """Selection rules for AVG_POOL2D operator tests."""

    # values encoded in each test name, in order
    param_names = [
        "shape",
        "type",
        "accum_type",
        "stride",
        "kernel",
        "pad",
    ]
    name = "avg_pool2d"
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100369
370
class BitwiseAndOperator(Operator):
    """Selection rules for BITWISE_AND operator tests."""

    name = "bitwise_and"
375
376
class BitwiseNotOperator(Operator):
    """Selection rules for BITWISE_NOT operator tests."""

    name = "bitwise_not"
381
382
class BitwiseOrOperator(Operator):
    """Selection rules for BITWISE_OR operator tests."""

    name = "bitwise_or"
387
388
class BitwiseXorOperator(Operator):
    """Selection rules for BITWISE_XOR operator tests."""

    name = "bitwise_xor"
393
394
class CastOperator(Operator):
    """Selection rules for CAST operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "output_type"]
    name = "cast"
400
401
class CeilOperator(Operator):
    """Selection rules for CEIL operator tests."""

    name = "ceil"
406
407
class ClampOperator(Operator):
    """Selection rules for CLAMP operator tests."""

    name = "clamp"
412
413
class CLZOperator(Operator):
    """Selection rules for CLZ operator tests."""

    name = "clz"
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100418
419
class ConcatOperator(Operator):
    """Selection rules for CONCAT operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "axis"]
    name = "concat"
425
426
class CondIfOperator(Operator):
    """Selection rules for COND_IF operator tests."""

    # values encoded in each test name, in order (note the leading variant)
    param_names = ["variant", "shape", "type", "cond"]
    name = "cond_if"
432
433
class ConstOperator(Operator):
    """Selection rules for CONST operator tests."""

    name = "const"
438
439
class Conv2dOperator(Operator):
    """Selection rules for CONV2D operator tests."""

    # values encoded in each test name, in order
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
    name = "conv2d"
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100445
446
class Conv3dOperator(Operator):
    """Selection rules for CONV3D operator tests."""

    # values encoded in each test name, in order
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
    name = "conv3d"
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100452
453
class DepthwiseConv2dOperator(Operator):
    """Selection rules for DEPTHWISE_CONV2D operator tests."""

    # values encoded in each test name, in order
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
    name = "depthwise_conv2d"
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100459
460
class EqualOperator(Operator):
    """Selection rules for EQUAL operator tests."""

    name = "equal"
465
466
class ExpOperator(Operator):
    """Selection rules for EXP operator tests."""

    name = "exp"
471
472
class FloorOperator(Operator):
    """Selection rules for FLOOR operator tests."""

    name = "floor"
477
478
class FullyConnectedOperator(Operator):
    """Selection rules for FULLY_CONNECTED operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "accum_type"]
    name = "fully_connected"
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100484
485
class GatherOperator(Operator):
    """Selection rules for GATHER operator tests."""

    name = "gather"
490
491
class GreaterOperator(Operator):
    """Selection rules for GREATER operator tests."""

    name = "greater"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Yield this operator's test paths from its exactly-named directory.

        The base class pattern "greater*" would also pick up the
        greater_equal operator's directories, so no wildcard is used here.
        """
        base_dir = f"{cls.name}"
        yield from Operator._get_test_paths(test_dir, base_dir, "*", negative)
501
502
class GreaterEqualOperator(Operator):
    """Selection rules for GREATER_EQUAL operator tests."""

    name = "greater_equal"
507
508
class IdentityOperator(Operator):
    """Selection rules for IDENTITY operator tests."""

    name = "identity"
513
514
class IntDivOperator(Operator):
    """Selection rules for INTDIV operator tests."""

    name = "intdiv"
519
520
class LogOperator(Operator):
    """Test selector for the LOG operator."""

    name = "log"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths for this operator.

        An exact directory name is used: the default "log*" pattern would also
        match the directories of the logical_* operators (logical_and,
        logical_not, ...), whose test names path_params would then mis-split
        into LOG parameters.
        """
        yield from Operator._get_test_paths(test_dir, f"{cls.name}", "*", negative)
525
526
class LogicalAndOperator(Operator):
    """Selection rules for LOGICAL_AND operator tests."""

    name = "logical_and"
531
532
class LogicalLeftShiftOperator(Operator):
    """Selection rules for LOGICAL_LEFT_SHIFT operator tests."""

    name = "logical_left_shift"
537
538
class LogicalNotOperator(Operator):
    """Selection rules for LOGICAL_NOT operator tests."""

    name = "logical_not"
543
544
class LogicalOrOperator(Operator):
    """Selection rules for LOGICAL_OR operator tests."""

    name = "logical_or"
549
550
class LogicalRightShiftOperator(Operator):
    """Selection rules for LOGICAL_RIGHT_SHIFT operator tests."""

    name = "logical_right_shift"
555
556
class LogicalXorOperator(Operator):
    """Selection rules for LOGICAL_XOR operator tests."""

    name = "logical_xor"
561
562
class MatmulOperator(Operator):
    """Selection rules for MATMUL operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "accum_type"]
    name = "matmul"
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100568
569
class MaximumOperator(Operator):
    """Selection rules for MAXIMUM operator tests."""

    name = "maximum"
574
575
class MaxPool2dOperator(Operator):
    """Selection rules for MAX_POOL2D operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "stride", "kernel", "pad"]
    name = "max_pool2d"
581
582
class MinimumOperator(Operator):
    """Selection rules for MINIMUM operator tests."""

    name = "minimum"
587
588
class MulOperator(Operator):
    """Selection rules for MUL operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "perm", "shift"]
    name = "mul"
594
595
class NegateOperator(Operator):
    """Selection rules for NEGATE operator tests."""

    name = "negate"
600
601
class PadOperator(Operator):
    """Selection rules for PAD operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "pad"]
    name = "pad"
607
608
class PowOperator(Operator):
    """Selection rules for POW operator tests."""

    name = "pow"
613
614
class ReciprocalOperator(Operator):
    """Selection rules for RECIPROCAL operator tests."""

    name = "reciprocal"
619
620
class ReduceAllOperator(Operator):
    """Selection rules for REDUCE_ALL operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "axis"]
    name = "reduce_all"
626
627
class ReduceAnyOperator(Operator):
    """Selection rules for REDUCE_ANY operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "axis"]
    name = "reduce_any"
633
634
class ReduceMaxOperator(Operator):
    """Selection rules for REDUCE_MAX operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "axis"]
    name = "reduce_max"
640
641
class ReduceMinOperator(Operator):
    """Selection rules for REDUCE_MIN operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "axis"]
    name = "reduce_min"
647
648
class ReduceSumOperator(Operator):
    """Selection rules for REDUCE_SUM operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "axis"]
    name = "reduce_sum"
654
655
class RescaleOperator(Operator):
    """Selection rules for RESCALE operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "output_type", "scale", "double_round", "per_channel"]
    name = "rescale"
668
669
class ReshapeOperator(Operator):
    """Selection rules for RESHAPE operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "perm", "rank"]
    name = "reshape"
675
676
class ResizeOperator(Operator):
    """Selection rules for RESIZE operator tests."""

    # values encoded in each test name, in order
    param_names = [
        "shape", "type", "mode", "output_type", "scale", "offset", "border",
    ]
    name = "resize"
690
691
class ReverseOperator(Operator):
    """Selection rules for REVERSE operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "axis"]
    name = "reverse"
697
698
class RsqrtOperator(Operator):
    """Selection rules for RSQRT operator tests."""

    name = "rsqrt"
703
704
class ScatterOperator(Operator):
    """Selection rules for SCATTER operator tests."""

    name = "scatter"
709
710
class SelectOperator(Operator):
    """Selection rules for SELECT operator tests."""

    name = "select"
715
716
class SigmoidOperator(Operator):
    """Selection rules for SIGMOID operator tests."""

    name = "sigmoid"
721
722
class SliceOperator(Operator):
    """Selection rules for SLICE operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "perm"]
    name = "slice"
728
729
class SubOperator(Operator):
    """Selection rules for SUB operator tests."""

    name = "sub"
734
735
class TableOperator(Operator):
    """Selection rules for TABLE operator tests."""

    name = "table"
740
741
class TanhOperator(Operator):
    """Selection rules for TANH operator tests."""

    name = "tanh"
746
747
class TileOperator(Operator):
    """Selection rules for TILE operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "perm"]
    name = "tile"
753
754
class TransposeOperator(Operator):
    """Selection rules for TRANSPOSE operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "perm"]
    name = "transpose"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Yield this operator's test paths from its exactly-named directory.

        The base class pattern "transpose*" would also pick up the
        transpose_conv2d operator's directories, so no wildcard is used here.
        """
        base_dir = f"{cls.name}"
        yield from Operator._get_test_paths(test_dir, base_dir, "*", negative)
765
766
class TransposeConv2dOperator(Operator):
    """Selection rules for TRANSPOSE_CONV2D operator tests."""

    name = "transpose_conv2d"
    # values encoded in each test name, in order
    param_names = [
        "kernel", "shape", "type", "accum_type", "stride", "pad", "out_shape",
    ]

    def path_params(self, path):
        """Split a test path into parameters, blanking out the out_shape value.

        Every test case has its own out_shape, so that value must not
        influence which tests get selected.
        """
        return {**super().path_params(path), "out_shape": ""}
787
788
class WhileLoopOperator(Operator):
    """Selection rules for WHILE_LOOP operator tests."""

    # values encoded in each test name, in order
    param_names = ["shape", "type", "cond"]
    name = "while_loop"
794
795
def parse_args():
    """Parse the command line arguments.

    Returns:
        argparse.Namespace with the attributes test_dir, config, full_path,
        verbosity, operators and test_type
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-dir",
        default=Path.cwd(),
        type=Path,
        help=(
            "The directory where test subdirectories for all operators can be found"
            " (default: current working directory)"
        ),
    )
    parser.add_argument(
        "--config",
        # by default, a JSON file next to this script with the same stem
        default=Path(__file__).with_suffix(".json"),
        type=Path,
        help="A JSON file defining the parameters to use for each operator",
    )
    parser.add_argument(
        "--full-path", action="store_true", help="output the full path for each test"
    )
    parser.add_argument(
        "-v",
        dest="verbosity",
        action="count",
        default=0,
        help="Verbosity (can be used multiple times for more details)",
    )
    parser.add_argument(
        "operators",
        type=str,
        nargs="*",
        help=(
            "Select tests for the specified operator(s)"
            " (all operators are assumed if none are specified)"
            f" - choose from: {list(Operator.registry)}"
        ),
    )
    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative"],
        default="positive",
        type=str,
        help="type of tests selected, positive or negative",
    )
    return parser.parse_args()
843
844
def main():
    """Select tests for the requested operators and print one per line.

    Returns:
        0 on success, 2 if the config file cannot be read or parsed
    """
    args = parse_args()

    # each -v raises the log level one step, saturating at DEBUG
    loglevels = (logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG)
    logger.setLevel(loglevels[min(args.verbosity, len(loglevels) - 1)])
    logger.info(f"{__file__}: args: {args}")

    try:
        with open(args.config, "r", encoding="utf-8") as fd:
            config = json.load(fd)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError also covers
        # json.JSONDecodeError and UnicodeDecodeError
        logger.error(f"Config file error: {e}")
        return 2

    # report requested operators that have no selector, instead of
    # silently selecting nothing for them
    for op_name in args.operators:
        if op_name not in Operator.registry:
            logger.error(f"Unknown operator: {op_name}")

    negative = args.test_type == "negative"
    for op_name, op_class in Operator.registry.items():
        if not args.operators or op_name in args.operators:
            op_params = config[op_name] if op_name in config else {}
            op = op_class(args.test_dir, op_params, negative, exclude_types=["float"])
            for test_path in op.select_tests():
                print(test_path.resolve() if args.full_path else test_path.name)

    return 0
871
872
if __name__ == "__main__":
    # SystemExit carries main()'s return code without relying on the
    # site-injected exit() builtin (absent under "python -S")
    raise SystemExit(main())