blob: 061a109076978ea91ad92cbebe607a193d70d6f3 [file] [log] [blame]
Jeremy Johnson35396f22023-01-04 17:05:25 +00001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01002# SPDX-License-Identifier: Apache-2.0
3"""Select generated tests."""
4import argparse
5import itertools
6import json
7import logging
James Ward736fd1a2023-01-23 17:13:37 +00008import re
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01009from pathlib import Path
10from typing import Any
11from typing import Dict
12from typing import List
13
# Install a default root handler at import time so messages routed through
# the module logger below are printed to stderr.
logging.basicConfig()

# Shared logger for every selector in this module.
logger = logging.getLogger("test_select")
16
17
def expand_params(permutes: Dict[str, List[Any]], others: Dict[str, List[Any]]):
    """Generate permuted combinations of a dictionary of values and combine with others.

    permutes: a dictionary with sequences of values to be fully permuted
    others: a dictionary with sequences of values not fully permuted, but all used

    This yields dictionaries with one value from each of the items in permutes,
    combined with one value from each of the items in others.

    The number of dictionaries yielded is the larger of the number of
    permutations of the permutes values and the length of the longest
    sequence in others; anything shorter is cycled to fill that count.

    Example 1:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 5, "d": False},
        ]

    Example 2:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7, 8, 9], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 8, "d": False},
            {"a": 1, "b": 3, "c": 9, "d": True},
        ]

    If both dictionaries are empty a single empty dictionary is generated.

    Raises:
        ValueError if any item is in both permutes and others, or if any item
        has no values while at least one dictionary would be generated
    """
    for k in permutes:
        if k in others:
            raise ValueError(f"item conflict: {k}")

    p_keys = list(permutes)
    o_keys = list(others)

    # The number of permutations is the product of the value counts. With no
    # permutes this stays 1, so a single (empty) permutation is still produced
    # and a single, empty dictionary is returned if others is also empty.
    p_product_len = 1
    for values in permutes.values():
        p_product_len *= len(values)

    # The number of params dictionaries generated will be the maximum size
    # of the permuted values and the non-permuted values from others
    max_items = max([p_product_len] + [len(v) for v in others.values()])
    if max_items == 0:
        return

    # Cycling an empty sequence would raise StopIteration inside this
    # generator, which Python turns into an opaque RuntimeError (PEP 479),
    # so report the offending item explicitly instead.
    for k, values in itertools.chain(permutes.items(), others.items()):
        if not values:
            raise ValueError(f"no values for item: {k}")

    # a cyclic generator for the product of all the permuted values; each
    # product entry holds one value per key, in the same order as p_keys
    p_generator = itertools.cycle(itertools.product(*permutes.values()))
    # a separate cyclic generator for each of the others values
    o_generators = [itertools.cycle(v) for v in others.values()]

    for _ in range(max_items):
        params = dict(zip(p_keys, next(p_generator)))
        for key, generator in zip(o_keys, o_generators):
            params[key] = next(generator)
        yield params
106
107
class Operator:
    """Base class for operator specific selection properties.

    Each subclass sets ``name`` (the operator name that prefixes its test
    names) and, when its tests carry extra parameters, ``param_names``.
    Defining a subclass registers it in ``registry`` under its name.
    """

    # A registry of all Operator subclasses, indexed by the operator name
    registry = {}

    def __init_subclass__(cls, **kwargs):
        """Subclass initialiser to register all Operator classes."""
        super().__init_subclass__(**kwargs)
        cls.registry[cls.name] = cls

    # Derived classes must override the operator name
    name = None
    # Operators with additional parameters must override the param_names
    # NB: the order must match the order the values appear in the test names
    param_names = ["shape", "type"]

    # Working set of param_names - updated for negative tests
    wks_param_names = None

    def __init__(
        self,
        test_dir: Path,
        config: Dict[str, Dict[str, List[Any]]],
        negative=False,
    ):
        """Initialise the selection parameters for an operator.

        test_dir: the directory where the tests for all operators can
            be found
        config: a dictionary with:
            "params" - a dictionary with mappings of parameter
                names to the values to select (a sub-set of
                expected values for instance)
            "permutes" - a list of parameter names to be permuted
            "preselected" - a list of dictionaries containing
                parameter names and pre-chosen values
            "sparsity" - a dictionary of parameter names with a
                sparsity value
            "exclude_patterns" - a list of regex's whereby each
                match will not be considered for selection.
                Exclusion happens BEFORE test selection (i.e.
                before permutes are applied).
            "errorifs" - list of ERRORIF case names to be selected
                after exclusion (negative tests)
            The config dictionary itself is not modified.
        negative: bool indicating if negative testing is being selected
            which filters for ERRORIF in the test name and only selects
            the first test found (ERRORIF tests)

        EXAMPLE CONFIG (with non-json comments):
            "params": {
                "output_type": [
                    "outi8",
                    "outb"
                ]
            },
            "permutes": [
                "shape",
                "type"
            ],
            "sparsity": {
                "pad": 15
            },
            "preselected": [
                {
                    "shape": "6",
                    "type": "i8",
                    "pad": "pad00"
                }
            ],
            "exclude_patterns": [
                # Exclude positive (not ERRORIF) integer tests
                "^((?!ERRORIF).)*_(i8|i16|i32|b)_out(i8|i16|i32|b)",
                # Exclude negative (ERRORIF) i8 test
                ".*_ERRORIF_.*_i8_outi8"
            ],
            "errorifs": [
                "InputZeroPointNotZero"
            ]
        """
        assert isinstance(
            self.name, str
        ), f"{self.__class__.__name__}: {self.name} is not a valid operator name"

        self.negative = negative
        self.wks_param_names = self.param_names.copy()
        if self.negative:
            # need to override positive set up - use "errorifs" config if set
            # add in errorif case before shape to support all ops, including
            # different ops like COND_IF and CONVnD etc
            index = self.wks_param_names.index("shape")
            self.wks_param_names[index:index] = ["ERRORIF", "case"]
            params = {x: [] for x in self.wks_param_names}
            params["case"] = list(config.get("errorifs", []))
            permutes = []
            preselected = []
        else:
            # Copy so that filling in defaults below leaves the caller's
            # config untouched
            params = dict(config.get("params", {}))
            permutes = config.get("permutes", [])
            preselected = config.get("preselected", [])

        self.params = params
        self.permutes = permutes
        self.sparsity = config.get("sparsity", {})
        self.preselected = preselected
        self.exclude_patterns = config.get("exclude_patterns", [])
        self.non_permutes = [x for x in self.wks_param_names if x not in self.permutes]
        logger.info(f"{self.name}: permutes={self.permutes}")
        logger.info(f"{self.name}: non_permutes={self.non_permutes}")
        logger.info(f"{self.name}: exclude_patterns={self.exclude_patterns}")

        # Compile the exclusion regexes once instead of once per test path
        exclude_regexes = [re.compile(pattern) for pattern in self.exclude_patterns]
        self.test_paths = []
        excluded_paths = []
        for path in self.get_test_paths(test_dir, self.negative):
            if any(regex.fullmatch(path.name) for regex in exclude_regexes):
                excluded_paths.append(path)
            else:
                self.test_paths.append(path)

        logger.debug(f"{self.name}: regex excluded paths={excluded_paths}")

        if not self.test_paths:
            logger.error(f"no tests found for {self.name} in {test_dir}")
        logger.debug(f"{self.name}: paths={self.test_paths}")

        # get default parameter values for any not given in the config
        default_params = self.get_default_params()
        for param, values in default_params.items():
            if not self.params.get(param):
                self.params[param] = values
        for param in self.wks_param_names:
            logger.info(f"{self.name}: params[{param}]={self.params[param]}")

    @staticmethod
    def _get_test_paths(test_dir: Path, base_dir_glob, path_glob, negative):
        """Generate test paths for operators using operator specifics.

        Yields, in sorted order, every path matching path_glob inside each
        directory of test_dir that matches base_dir_glob. Only ERRORIF paths
        are yielded when negative is set, and only non-ERRORIF ones otherwise.
        The ERRORIF check looks at the path below test_dir only, so a
        test_dir whose own location contains "ERRORIF" does not change the
        classification.
        """
        want_errorif = bool(negative)
        for base_dir in sorted(test_dir.glob(base_dir_glob)):
            for path in sorted(base_dir.glob(path_glob)):
                is_errorif = "ERRORIF" in str(path.relative_to(test_dir))
                if is_errorif == want_errorif:
                    yield path

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths for this operator."""
        yield from Operator._get_test_paths(test_dir, f"{cls.name}*", "*", negative)

    def path_params(self, path):
        """Return a dictionary of params from the test path.

        The test name (path.name) is split on "_"; the leading parts that
        make up the operator name are dropped and the remaining parts are
        assigned, in order, to the working parameter names.
        """
        op_name_parts = self.name.split("_")
        values = path.name.split("_")[len(op_name_parts) :]
        assert len(values) == len(
            self.wks_param_names
        ), f"len({values}) == len({self.wks_param_names})"
        return dict(zip(self.wks_param_names, values))

    def get_default_params(self):
        """Get the default parameter values from the test names.

        Returns a dictionary mapping each working parameter name to the
        sorted list of distinct values found across all test paths.
        """
        params = {param: set() for param in self.wks_param_names}
        for path in self.test_paths:
            path_params = self.path_params(path)
            for k in params:
                params[k].add(path_params[k])
        return {param: sorted(values) for param, values in params.items()}

    def select_tests(self):  # noqa: C901 (function too complex)
        """Generate the paths to the selected tests for this operator.

        Selection happens in two passes:
        1. tests whose parameters exactly match a permuted or preselected
           combination (for negative tests only one per ERRORIF case)
        2. remaining tests that use a parameter value not yet covered by a
           selected test, skipping some according to the sparsity settings
        Anything requested but not found is reported through the logger.
        """
        if not self.test_paths:
            # Exit early when nothing to select from
            return

        # the test paths that have not been selected yet
        unused_paths = set(self.test_paths)

        # a list of dictionaries of unused preselected parameter combinations
        unused_preselected = list(self.preselected)
        logger.debug(f"preselected: {unused_preselected}")

        # a list of dictionaries of unused permuted parameter combinations
        permutes = {k: self.params[k] for k in self.permutes}
        others = {k: self.params[k] for k in self.non_permutes}
        unused_permuted = list(expand_params(permutes, others))
        logger.debug(f"permuted: {unused_permuted}")

        # a dictionary of sets of unused parameter values
        if self.negative:
            # We only care about selecting a test for each errorif case
            unused_values = {k: set() for k in self.params}
            unused_values["case"] = set(self.params["case"])
        else:
            unused_values = {k: set(v) for k, v in self.params.items()}

        # select tests matching permuted, or preselected, parameter combinations
        for path in self.test_paths:
            path_params = self.path_params(path)
            if path_params in unused_permuted or path_params in unused_preselected:
                unused_paths.remove(path)
                if path_params in unused_preselected:
                    unused_preselected.remove(path_params)
                if path_params in unused_permuted:
                    unused_permuted.remove(path_params)
                if self.negative:
                    # remove any other errorif cases, so we only match one
                    unused_permuted = [
                        p for p in unused_permuted if p["case"] != path_params["case"]
                    ]
                # remove the param values used by this path
                for k in path_params:
                    unused_values[k].discard(path_params[k])
                logger.debug(f"FOUND: {path.name}")
                yield path

        # search for tests that match any unused parameter values
        for n, path in enumerate(sorted(unused_paths)):
            path_params = self.path_params(path)
            # select paths with unused param values
            # skipping some, if sparsity is set for the param
            for k in path_params:
                if path_params[k] in unused_values[k] and (
                    k not in self.sparsity or n % self.sparsity[k] == 0
                ):
                    # remove the param values used by this path
                    for p in path_params:
                        unused_values[p].discard(path_params[p])
                    logger.debug(f"FOUND: {path.name}")
                    yield path
                    break

        # report any preselected combinations that were not found
        for params in unused_preselected:
            logger.warning(f"MISSING preselected: {params}")
        # report any permuted combinations that were not found
        for params in unused_permuted:
            logger.debug(f"MISSING permutation: {params}")
        # report any param values that were not found
        for k, values in unused_values.items():
            if values:
                if k not in self.sparsity:
                    logger.warning(f"MISSING {len(values)} values for {k}: {values}")
                else:
                    logger.info(
                        f"Skipped {len(values)} values for {k} due to sparsity setting"
                    )
                    logger.debug(f"Values skipped: {values}")
362
363
class AbsOperator(Operator):
    """Selects tests of the ABS operator."""

    name = "abs"
368
369
class ArithmeticRightShiftOperator(Operator):
    """Selects tests of the ARITHMETIC_RIGHT_SHIFT operator."""

    name = "arithmetic_right_shift"
    param_names = [
        "shape",
        "type",
        "rounding",
    ]
375
376
class AddOperator(Operator):
    """Selects tests of the ADD operator."""

    name = "add"
381
382
class ArgmaxOperator(Operator):
    """Selects tests of the ARGMAX operator."""

    name = "argmax"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
388
389
class AvgPool2dOperator(Operator):
    """Selects tests of the AVG_POOL2D operator."""

    name = "avg_pool2d"
    param_names = [
        "shape",
        "type",
        "accum_type",
        "stride",
        "kernel",
        "pad",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100395
396
class BitwiseAndOperator(Operator):
    """Selects tests of the BITWISE_AND operator."""

    name = "bitwise_and"
401
402
class BitwiseNotOperator(Operator):
    """Selects tests of the BITWISE_NOT operator."""

    name = "bitwise_not"
407
408
class BitwiseOrOperator(Operator):
    """Selects tests of the BITWISE_OR operator."""

    name = "bitwise_or"
413
414
class BitwiseXorOperator(Operator):
    """Selects tests of the BITWISE_XOR operator."""

    name = "bitwise_xor"
419
420
class CastOperator(Operator):
    """Selects tests of the CAST operator."""

    name = "cast"
    param_names = [
        "shape",
        "type",
        "output_type",
    ]
426
427
class CeilOperator(Operator):
    """Selects tests of the CEIL operator."""

    name = "ceil"
432
433
class ClampOperator(Operator):
    """Selects tests of the CLAMP operator."""

    name = "clamp"
438
439
class CLZOperator(Operator):
    """Selects tests of the CLZ operator."""

    name = "clz"
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100444
445
class ConcatOperator(Operator):
    """Selects tests of the CONCAT operator."""

    name = "concat"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
451
452
class CondIfOperator(Operator):
    """Selects tests of the COND_IF operator."""

    name = "cond_if"
    # the variant comes first in COND_IF test names, ahead of the shape
    param_names = [
        "variant",
        "shape",
        "type",
        "cond",
    ]
458
459
class ConstOperator(Operator):
    """Selects tests of the CONST operator."""

    name = "const"
464
465
class Conv2dOperator(Operator):
    """Selects tests of the CONV2D operator."""

    name = "conv2d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100471
472
class Conv3dOperator(Operator):
    """Selects tests of the CONV3D operator."""

    name = "conv3d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100478
479
class DepthwiseConv2dOperator(Operator):
    """Selects tests of the DEPTHWISE_CONV2D operator."""

    name = "depthwise_conv2d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100485
486
class EqualOperator(Operator):
    """Selects tests of the EQUAL operator."""

    name = "equal"
491
492
class ExpOperator(Operator):
    """Selects tests of the EXP operator."""

    name = "exp"
497
498
class FloorOperator(Operator):
    """Selects tests of the FLOOR operator."""

    name = "floor"
503
504
class FullyConnectedOperator(Operator):
    """Selects tests of the FULLY_CONNECTED operator."""

    name = "fully_connected"
    param_names = [
        "shape",
        "type",
        "accum_type",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100510
511
class GatherOperator(Operator):
    """Selects tests of the GATHER operator."""

    name = "gather"
516
517
class GreaterOperator(Operator):
    """Selects tests of the GREATER operator."""

    name = "greater"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Yield this operator's test paths.

        The base directory is matched by exact name rather than as a
        "greater*" prefix, which would also match "greater_equal".
        """
        yield from Operator._get_test_paths(
            test_dir,
            base_dir_glob=f"{cls.name}",
            path_glob="*",
            negative=negative,
        )
527
528
class GreaterEqualOperator(Operator):
    """Selects tests of the GREATER_EQUAL operator."""

    name = "greater_equal"
533
534
class IdentityOperator(Operator):
    """Selects tests of the IDENTITY operator."""

    name = "identity"
539
540
class IntDivOperator(Operator):
    """Selects tests of the INTDIV operator."""

    name = "intdiv"
545
546
class LogOperator(Operator):
    """Test selector for the LOG operator."""

    name = "log"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths for this operator.

        The inherited "log*" base directory glob would also match every
        logical_* operator directory (logical_and, logical_not, ...). Their
        test names do not split into LOG's parameters, so path_params would
        fail on them. Match the "log" directory exactly instead, as the
        GREATER and TRANSPOSE selectors do for their prefix clashes.
        """
        yield from Operator._get_test_paths(test_dir, f"{cls.name}", "*", negative)
551
552
class LogicalAndOperator(Operator):
    """Selects tests of the LOGICAL_AND operator."""

    name = "logical_and"
557
558
class LogicalLeftShiftOperator(Operator):
    """Selects tests of the LOGICAL_LEFT_SHIFT operator."""

    name = "logical_left_shift"
563
564
class LogicalNotOperator(Operator):
    """Selects tests of the LOGICAL_NOT operator."""

    name = "logical_not"
569
570
class LogicalOrOperator(Operator):
    """Selects tests of the LOGICAL_OR operator."""

    name = "logical_or"
575
576
class LogicalRightShiftOperator(Operator):
    """Selects tests of the LOGICAL_RIGHT_SHIFT operator."""

    name = "logical_right_shift"
581
582
class LogicalXorOperator(Operator):
    """Selects tests of the LOGICAL_XOR operator."""

    name = "logical_xor"
587
588
class MatmulOperator(Operator):
    """Selects tests of the MATMUL operator."""

    name = "matmul"
    param_names = [
        "shape",
        "type",
        "accum_type",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100594
595
class MaximumOperator(Operator):
    """Selects tests of the MAXIMUM operator."""

    name = "maximum"
600
601
class MaxPool2dOperator(Operator):
    """Selects tests of the MAX_POOL2D operator."""

    name = "max_pool2d"
    param_names = [
        "shape",
        "type",
        "stride",
        "kernel",
        "pad",
    ]
607
608
class MinimumOperator(Operator):
    """Selects tests of the MINIMUM operator."""

    name = "minimum"
613
614
class MulOperator(Operator):
    """Selects tests of the MUL operator."""

    name = "mul"
    param_names = [
        "shape",
        "type",
        "perm",
        "shift",
    ]
620
621
class NegateOperator(Operator):
    """Selects tests of the NEGATE operator."""

    name = "negate"
626
627
class PadOperator(Operator):
    """Selects tests of the PAD operator."""

    name = "pad"
    param_names = [
        "shape",
        "type",
        "pad",
    ]
633
634
class PowOperator(Operator):
    """Selects tests of the POW operator."""

    name = "pow"
639
640
class ReciprocalOperator(Operator):
    """Selects tests of the RECIPROCAL operator."""

    name = "reciprocal"
645
646
class ReduceAllOperator(Operator):
    """Selects tests of the REDUCE_ALL operator."""

    name = "reduce_all"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
652
653
class ReduceAnyOperator(Operator):
    """Selects tests of the REDUCE_ANY operator."""

    name = "reduce_any"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
659
660
class ReduceMaxOperator(Operator):
    """Selects tests of the REDUCE_MAX operator."""

    name = "reduce_max"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
666
667
class ReduceMinOperator(Operator):
    """Selects tests of the REDUCE_MIN operator."""

    name = "reduce_min"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
673
674
class ReduceProductOperator(Operator):
    """Selects tests of the REDUCE_PRODUCT operator."""

    name = "reduce_product"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
680
681
class ReduceSumOperator(Operator):
    """Selects tests of the REDUCE_SUM operator."""

    name = "reduce_sum"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
687
688
class RescaleOperator(Operator):
    """Selects tests of the RESCALE operator."""

    name = "rescale"
    param_names = ["shape", "type", "output_type"] + [
        "scale",
        "double_round",
        "per_channel",
    ]
701
702
class ReshapeOperator(Operator):
    """Selects tests of the RESHAPE operator."""

    name = "reshape"
    param_names = [
        "shape",
        "type",
        "perm",
        "rank",
    ]
708
709
class ResizeOperator(Operator):
    """Selects tests of the RESIZE operator."""

    name = "resize"
    param_names = ["shape", "type", "mode", "output_type", "scale", "offset", "border"]
723
724
class ReverseOperator(Operator):
    """Selects tests of the REVERSE operator."""

    name = "reverse"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
730
731
class RsqrtOperator(Operator):
    """Selects tests of the RSQRT operator."""

    name = "rsqrt"
736
737
class ScatterOperator(Operator):
    """Selects tests of the SCATTER operator."""

    name = "scatter"
742
743
class SelectOperator(Operator):
    """Selects tests of the SELECT operator."""

    name = "select"
748
749
class SigmoidOperator(Operator):
    """Selects tests of the SIGMOID operator."""

    name = "sigmoid"
754
755
class SliceOperator(Operator):
    """Selects tests of the SLICE operator."""

    name = "slice"
    param_names = [
        "shape",
        "type",
        "perm",
    ]
761
762
class SubOperator(Operator):
    """Selects tests of the SUB operator."""

    name = "sub"
767
768
class TableOperator(Operator):
    """Selects tests of the TABLE operator."""

    name = "table"
773
774
class TanhOperator(Operator):
    """Selects tests of the TANH operator."""

    name = "tanh"
779
780
class TileOperator(Operator):
    """Selects tests of the TILE operator."""

    name = "tile"
    param_names = [
        "shape",
        "type",
        "perm",
    ]
786
787
class TransposeOperator(Operator):
    """Selects tests of the TRANSPOSE operator."""

    name = "transpose"
    param_names = [
        "shape",
        "type",
        "perm",
    ]

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Yield this operator's test paths.

        The base directory is matched by exact name rather than as a
        "transpose*" prefix, which would also match "transpose_conv2d".
        """
        yield from Operator._get_test_paths(
            test_dir,
            base_dir_glob=f"{cls.name}",
            path_glob="*",
            negative=negative,
        )
798
799
class TransposeConv2dOperator(Operator):
    """Selects tests of the TRANSPOSE_CONV2D operator."""

    name = "transpose_conv2d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "out_shape",
    ]

    def path_params(self, path):
        """Return the test parameters parsed from path, with out_shape blanked.

        Every test case has its own out_shape, so it is cleared to keep it
        from influencing selection.
        """
        return {**super().path_params(path), "out_shape": ""}
820
821
class WhileLoopOperator(Operator):
    """Selects tests of the WHILE_LOOP operator."""

    name = "while_loop"
    param_names = [
        "shape",
        "type",
        "cond",
    ]
827
828
def parse_args(argv=None):
    """Parse the command line arguments.

    argv: the list of argument strings to parse; when None (the default)
        argparse reads them from sys.argv, as before

    Returns the argparse.Namespace holding test_dir, config, full_path,
    verbosity, operators and test_type.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-dir",
        default=Path.cwd(),
        type=Path,
        help=(
            "The directory where test subdirectories for all operators can be found"
            " (default: current working directory)"
        ),
    )
    parser.add_argument(
        "--config",
        default=Path(__file__).with_suffix(".json"),
        type=Path,
        help="A JSON file defining the parameters to use for each operator",
    )
    parser.add_argument(
        "--full-path", action="store_true", help="output the full path for each test"
    )
    parser.add_argument(
        "-v",
        dest="verbosity",
        action="count",
        default=0,
        help="Verbosity (can be used multiple times for more details)",
    )
    parser.add_argument(
        "operators",
        type=str,
        nargs="*",
        help=(
            "Select tests for the specified operator(s)"
            " (all operators are assumed if none are specified)"
            f" - choose from: {list(Operator.registry)}"
        ),
    )
    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative"],
        default="positive",
        type=str,
        help="type of tests selected, positive or negative",
    )
    return parser.parse_args(argv)
876
877
def main():
    """Select tests for the requested operators and print them.

    Prints one line per selected test: its name, or its resolved full path
    when --full-path is given. Returns the process exit code: 0 on success,
    2 if the config file cannot be read or parsed.
    """
    args = parse_args()

    loglevels = (logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG)
    logger.setLevel(loglevels[min(args.verbosity, len(loglevels) - 1)])
    logger.info(f"{__file__}: args: {args}")

    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding
        with open(args.config, "r", encoding="utf-8") as fd:
            config = json.load(fd)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError: invalid JSON
        # (json.JSONDecodeError) or undecodable bytes (UnicodeDecodeError)
        logger.error(f"Config file error: {e}")
        return 2

    # Report requested operators that match no selector instead of silently
    # selecting nothing for them
    for op_name in args.operators:
        if op_name not in Operator.registry:
            logger.error(f"unknown operator: {op_name}")

    negative = args.test_type == "negative"
    for op_name, op_class in Operator.registry.items():
        if not args.operators or op_name in args.operators:
            op_params = config[op_name] if op_name in config else {}
            op = op_class(args.test_dir, op_params, negative)
            for test_path in op.select_tests():
                print(test_path.resolve() if args.full_path else test_path.name)

    return 0
902
903
if __name__ == "__main__":
    # exit() is the interactive helper installed by the site module and is
    # missing under "python -S"; sys.exit is the reliable way to set the
    # process exit code from a script
    import sys

    sys.exit(main())