blob: 72a4e849430d6969ecc48181790a8ab02a4873cd [file] [log] [blame]
Jeremy Johnson35396f22023-01-04 17:05:25 +00001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01002# SPDX-License-Identifier: Apache-2.0
3"""Select generated tests."""
4import argparse
5import itertools
6import json
7import logging
James Ward736fd1a2023-01-23 17:13:37 +00008import re
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01009from pathlib import Path
10from typing import Any
11from typing import Dict
12from typing import List
13
# Install a default root handler so records from this module are printed.
logging.basicConfig()
# Module logger used by every operator selector below.
logger = logging.getLogger(name="test_select")
16
17
def expand_params(permutes: Dict[str, List[Any]], others: Dict[str, List[Any]]):
    """Generate permuted combinations of a dictionary of values and combine with others.

    permutes: a dictionary with sequences of values to be fully permuted
    others: a dictionary with sequences of values not fully permuted, but all used

    This yields dictionaries with one value from each of the items in permutes,
    combined with one value from each of the items in others. The number of
    dictionaries yielded is the larger of the size of the full product of the
    permutes values and the length of the longest sequence in others; shorter
    sequences are cycled to fill the gap. If any sequence in permutes or others
    is empty no complete combination exists, so nothing is yielded.

    Example 1:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7], "d": [True, False]}

        generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 5, "d": False},
        ]

    Example 2:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7, 8, 9], "d": [True, False]}

        generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 8, "d": False},
            {"a": 1, "b": 3, "c": 9, "d": True},
        ]

    Raises:
        ValueError if any item is in both permutes and others
    """
    for k in permutes:
        if k in others:
            raise ValueError(f"item conflict: {k}")

    # An empty sequence anywhere means no complete combination can be built.
    # Without this guard, next() on a cycle over an empty sequence raises
    # StopIteration inside this generator, which Python converts into a
    # RuntimeError (PEP 479).
    if any(len(v) == 0 for v in itertools.chain(permutes.values(), others.values())):
        return

    p_keys = list(permutes)
    # Number of distinct permuted combinations. For an empty permutes dict
    # this stays 1, because the product of no sequences is one empty tuple,
    # so a single (possibly empty) dictionary is yielded.
    p_product_len = 1
    for v in permutes.values():
        p_product_len *= len(v)
    # itertools.product yields values in the same key order as p_keys
    p_generator = itertools.cycle(itertools.product(*permutes.values()))

    # one independent cyclic generator per "others" item
    o_keys = list(others)
    o_generators = [itertools.cycle(v) for v in others.values()]

    # The number of params dictionaries generated is the maximum size
    # of the permuted values and the non-permuted values from others
    max_items = max([p_product_len] + [len(v) for v in others.values()])

    for _ in range(max_items):
        # permuted values first, then one value from each of the others
        params = dict(zip(p_keys, next(p_generator)))
        for key, generator in zip(o_keys, o_generators):
            params[key] = next(generator)
        yield params
106
107
class Operator:
    """Base class for operator specific selection properties.

    Subclasses set ``name`` (and ``param_names`` when their test names carry
    extra parameters) and are automatically added to ``Operator.registry``.
    """

    # A registry of all Operator subclasses, indexed by the operator name
    registry = {}

    def __init_subclass__(cls, **kwargs):
        """Subclass initialiser to register all Operator classes."""
        super().__init_subclass__(**kwargs)
        cls.registry[cls.name] = cls

    # Derived classes must override the operator name
    name = None
    # Operators with additional parameters must override the param_names
    # NB: the order must match the order the values appear in the test names
    param_names = ["shape", "type"]

    # Working set of param_names - updated for negative tests
    wks_param_names = None

    def __init__(
        self,
        test_dir: Path,
        config: Dict[str, Dict[str, List[Any]]],
        negative=False,
    ):
        """Initialise the selection parameters for an operator.

        test_dir: the directory where the tests for all operators can
            be found
        config: a dictionary with:
            "params" - a dictionary with mappings of parameter
                names to the values to select (a sub-set of
                expected values for instance)
            "permutes" - a list of parameter names to be permuted
            "preselected" - a list of dictionaries containing
                parameter names and pre-chosen values
            "sparsity" - a dictionary of parameter names with a
                sparsity value
            "exclude_patterns" - a list of regexes whereby each
                match will not be considered for selection.
                Exclusion happens BEFORE test selection (i.e.
                before permutes are applied).
            "errorifs" - list of ERRORIF case names to be selected
                after exclusion (negative tests)
            The config dictionary is not modified.
        negative: bool indicating if negative testing is being selected
            which filters for ERRORIF in the test name and only selects
            the first test found (ERRORIF tests); in this mode the
            "params", "permutes" and "preselected" config entries are
            ignored

        EXAMPLE CONFIG (with non-json comments):
            "params": {
                "output_type": [
                    "outi8",
                    "outb"
                ]
            },
            "permutes": [
                "shape",
                "type"
            ],
            "sparsity": {
                "pad": 15
            },
            "preselected": [
                {
                    "shape": "6",
                    "type": "i8",
                    "pad": "pad00"
                }
            ],
            "exclude_patterns": [
                # Exclude positive (not ERRORIF) integer tests
                "^((?!ERRORIF).)*_(i8|i16|i32|b)_out(i8|i16|i32|b)",
                # Exclude negative (ERRORIF) i8 test
                ".*_ERRORIF_.*_i8_outi8"
            ],
            "errorifs": [
                "InputZeroPointNotZero"
            ]
        """
        assert isinstance(
            self.name, str
        ), f"{self.__class__.__name__}: {self.name} is not a valid operator name"

        self.negative = negative
        self.wks_param_names = self.param_names.copy()
        if self.negative:
            # need to override positive set up - use "errorifs" config if set
            # add in errorif case before shape to support all ops, including
            # different ops like COND_IF and CONVnD etc
            index = self.wks_param_names.index("shape")
            self.wks_param_names[index:index] = ["ERRORIF", "case"]
            params = {x: [] for x in self.wks_param_names}
            params["case"] = list(config.get("errorifs", []))
            permutes = []
            preselected = []
        else:
            # shallow copy, so filling in default values below does not
            # write back into the caller's config
            params = dict(config.get("params", {}))
            permutes = config.get("permutes", [])
            preselected = config.get("preselected", [])

        self.params = params
        self.permutes = permutes
        self.sparsity = config.get("sparsity", {})
        self.preselected = preselected
        self.exclude_patterns = config.get("exclude_patterns", [])
        self.non_permutes = [x for x in self.wks_param_names if x not in self.permutes]
        logger.info(f"{self.name}: permutes={self.permutes}")
        logger.info(f"{self.name}: non_permutes={self.non_permutes}")
        logger.info(f"{self.name}: exclude_patterns={self.exclude_patterns}")

        # drop any test whose name fully matches one of the exclude patterns
        self.test_paths = []
        excluded_paths = []
        for path in self.get_test_paths(test_dir, self.negative):
            if any(re.fullmatch(pattern, path.name) for pattern in self.exclude_patterns):
                excluded_paths.append(path)
            else:
                self.test_paths.append(path)

        logger.debug(f"{self.name}: regex excluded paths={excluded_paths}")

        if not self.test_paths:
            logger.error(f"no tests found for {self.name} in {test_dir}")
        logger.debug(f"{self.name}: paths={self.test_paths}")

        # get default parameter values for any not given in the config
        default_params = self.get_default_params()
        for param, values in default_params.items():
            if not self.params.get(param):
                self.params[param] = values
        for param in self.wks_param_names:
            logger.info(f"{self.name}: params[{param}]={self.params[param]}")

    @staticmethod
    def _get_test_paths(test_dir: Path, base_dir_glob, path_glob, negative):
        """Generate test paths for operators using operator specifics.

        Yields, in sorted order, every entry matching path_glob inside every
        directory of test_dir matching base_dir_glob. Only paths containing
        "ERRORIF" are yielded when negative is set, and only paths without it
        otherwise.
        """
        for base_dir in sorted(test_dir.glob(base_dir_glob)):
            for path in sorted(base_dir.glob(path_glob)):
                if ("ERRORIF" in str(path)) == bool(negative):
                    yield path

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths for this operator."""
        yield from Operator._get_test_paths(test_dir, f"{cls.name}*", "*", negative)

    def path_params(self, path):
        """Return a dictionary of params from the test path.

        The test name is split on "_"; the leading parts that make up the
        operator name are skipped and the remaining parts are assigned, in
        order, to wks_param_names.
        """
        op_name_parts = self.name.split("_")
        values = path.name.split("_")[len(op_name_parts) :]
        assert len(values) == len(
            self.wks_param_names
        ), f"len({values}) == len({self.wks_param_names})"
        return dict(zip(self.wks_param_names, values))

    def get_default_params(self):
        """Get the default parameter values from the test names.

        Returns a dictionary mapping each working param name to the sorted
        list of distinct values seen across self.test_paths.
        """
        params = {param: set() for param in self.wks_param_names}
        for path in self.test_paths:
            path_params = self.path_params(path)
            for k in params:
                params[k].add(path_params[k])
        return {param: sorted(values) for param, values in params.items()}

    def select_tests(self):  # noqa: C901 (function too complex)
        """Generate the paths to the selected tests for this operator.

        Selection makes two passes over the (non-excluded) test paths:
        1. tests whose parameters exactly match a permuted or preselected
           combination;
        2. remaining tests that supply a parameter value not yet covered,
           skipping some according to any per-parameter sparsity setting.
        Combinations and values that were never matched are logged.
        """
        if not self.test_paths:
            # Exit early when nothing to select from
            return

        # the test paths that have not been selected yet
        unused_paths = set(self.test_paths)

        # a list of dictionaries of unused preselected parameter combinations
        unused_preselected = list(self.preselected)
        logger.debug(f"preselected: {unused_preselected}")

        # a list of dictionaries of unused permuted parameter combinations
        permutes = {k: self.params[k] for k in self.permutes}
        others = {k: self.params[k] for k in self.non_permutes}
        unused_permuted = list(expand_params(permutes, others))
        logger.debug(f"permuted: {unused_permuted}")

        # a dictionary of sets of unused parameter values
        if self.negative:
            # We only care about selecting a test for each errorif case
            unused_values = {k: set() for k in self.params}
            unused_values["case"] = set(self.params["case"])
        else:
            unused_values = {k: set(v) for k, v in self.params.items()}

        # pass 1: select tests matching permuted, or preselected, combinations
        for path in self.test_paths:
            path_params = self.path_params(path)
            in_permuted = path_params in unused_permuted
            in_preselected = path_params in unused_preselected
            if not (in_permuted or in_preselected):
                continue
            unused_paths.remove(path)
            if in_preselected:
                unused_preselected.remove(path_params)
            if in_permuted:
                unused_permuted.remove(path_params)
            if self.negative:
                # remove any other errorif cases, so we only match one
                unused_permuted = [
                    p for p in unused_permuted if p["case"] != path_params["case"]
                ]
            # remove the param values used by this path
            for k, v in path_params.items():
                unused_values[k].discard(v)
            logger.debug(f"FOUND: {path.name}")
            yield path

        # pass 2: search for tests that match any unused parameter values
        for n, path in enumerate(sorted(unused_paths)):
            path_params = self.path_params(path)
            # select paths with unused param values
            # skipping some, if sparsity is set for the param
            for k in path_params:
                if path_params[k] in unused_values[k] and (
                    k not in self.sparsity or n % self.sparsity[k] == 0
                ):
                    # remove the param values used by this path
                    for p, v in path_params.items():
                        unused_values[p].discard(v)
                    logger.debug(f"FOUND: {path.name}")
                    yield path
                    break

        # report any preselected combinations that were not found
        for params in unused_preselected:
            logger.warning(f"MISSING preselected: {params}")
        # report any permuted combinations that were not found
        for params in unused_permuted:
            logger.debug(f"MISSING permutation: {params}")
        # report any param values that were not found
        for k, values in unused_values.items():
            if values:
                if k not in self.sparsity:
                    logger.warning(f"MISSING {len(values)} values for {k}: {values}")
                else:
                    logger.info(
                        f"Skipped {len(values)} values for {k} due to sparsity setting"
                    )
                    logger.debug(f"Values skipped: {values}")
361 logger.debug(f"Values skipped: {values}")
362
363
class AbsOperator(Operator):
    """Select tests for the ABS operator."""

    name = "abs"
368
369
class ArithmeticRightShiftOperator(Operator):
    """Select tests for the ARITHMETIC_RIGHT_SHIFT operator."""

    name = "arithmetic_right_shift"
    param_names = [
        "shape",
        "type",
        "rounding",
    ]
375
376
class AddOperator(Operator):
    """Select tests for the ADD operator."""

    name = "add"
381
382
class ArgmaxOperator(Operator):
    """Select tests for the ARGMAX operator."""

    name = "argmax"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
388
389
class AvgPool2dOperator(Operator):
    """Select tests for the AVG_POOL2D operator."""

    name = "avg_pool2d"
    param_names = [
        "shape",
        "type",
        "accum_type",
        "stride",
        "kernel",
        "pad",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100395
396
class BitwiseAndOperator(Operator):
    """Select tests for the BITWISE_AND operator."""

    name = "bitwise_and"
401
402
class BitwiseNotOperator(Operator):
    """Select tests for the BITWISE_NOT operator."""

    name = "bitwise_not"
407
408
class BitwiseOrOperator(Operator):
    """Select tests for the BITWISE_OR operator."""

    name = "bitwise_or"
413
414
class BitwiseXorOperator(Operator):
    """Select tests for the BITWISE_XOR operator."""

    name = "bitwise_xor"
419
420
class CastOperator(Operator):
    """Select tests for the CAST operator."""

    name = "cast"
    param_names = [
        "shape",
        "type",
        "output_type",
    ]
426
427
class CeilOperator(Operator):
    """Select tests for the CEIL operator."""

    name = "ceil"
432
433
class ClampOperator(Operator):
    """Select tests for the CLAMP operator."""

    name = "clamp"
438
439
class CLZOperator(Operator):
    """Select tests for the CLZ operator."""

    name = "clz"
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100444
445
class ConcatOperator(Operator):
    """Select tests for the CONCAT operator."""

    name = "concat"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
451
452
class CondIfOperator(Operator):
    """Select tests for the COND_IF operator."""

    name = "cond_if"
    # the variant comes before the shape in COND_IF test names
    param_names = [
        "variant",
        "shape",
        "type",
        "cond",
    ]
458
459
class ConstOperator(Operator):
    """Select tests for the CONST operator."""

    name = "const"
464
465
class Conv2dOperator(Operator):
    """Select tests for the CONV2D operator."""

    name = "conv2d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100471
472
class Conv3dOperator(Operator):
    """Select tests for the CONV3D operator."""

    name = "conv3d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100478
479
class DepthwiseConv2dOperator(Operator):
    """Select tests for the DEPTHWISE_CONV2D operator."""

    name = "depthwise_conv2d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100485
486
class EqualOperator(Operator):
    """Select tests for the EQUAL operator."""

    name = "equal"
491
492
class ExpOperator(Operator):
    """Select tests for the EXP operator."""

    name = "exp"
497
498
class FFT2DOperator(Operator):
    """Select tests for the FFT2D operator."""

    name = "fft2d"
    param_names = [
        "shape",
        "type",
        "inverse",
    ]
504
505
class FloorOperator(Operator):
    """Select tests for the FLOOR operator."""

    name = "floor"
510
511
class FullyConnectedOperator(Operator):
    """Select tests for the FULLY_CONNECTED operator."""

    name = "fully_connected"
    param_names = [
        "shape",
        "type",
        "accum_type",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100517
518
class GatherOperator(Operator):
    """Select tests for the GATHER operator."""

    name = "gather"
523
524
class GreaterOperator(Operator):
    """Select tests for the GREATER operator."""

    name = "greater"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Yield test paths from the directory named exactly after the operator.

        The inherited "greater*" glob would also pick up the GREATER_EQUAL
        tests, so only an exact directory name match is used.
        """
        exact_dir = str(cls.name)
        for path in Operator._get_test_paths(test_dir, exact_dir, "*", negative):
            yield path
534
535
class GreaterEqualOperator(Operator):
    """Select tests for the GREATER_EQUAL operator."""

    name = "greater_equal"
540
541
class IdentityOperator(Operator):
    """Select tests for the IDENTITY operator."""

    name = "identity"
546
547
class IntDivOperator(Operator):
    """Select tests for the INTDIV operator."""

    name = "intdiv"
552
553
class LogOperator(Operator):
    """Test selector for the LOG operator."""

    name = "log"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths for this operator.

        Only the directory named exactly "log" is searched: the inherited
        "log*" glob would also match every LOGICAL_* operator directory,
        whose test names cannot be parsed with the LOG parameter names.
        """
        yield from Operator._get_test_paths(test_dir, f"{cls.name}", "*", negative)
558
559
class LogicalAndOperator(Operator):
    """Select tests for the LOGICAL_AND operator."""

    name = "logical_and"
564
565
class LogicalLeftShiftOperator(Operator):
    """Select tests for the LOGICAL_LEFT_SHIFT operator."""

    name = "logical_left_shift"
570
571
class LogicalNotOperator(Operator):
    """Select tests for the LOGICAL_NOT operator."""

    name = "logical_not"
576
577
class LogicalOrOperator(Operator):
    """Select tests for the LOGICAL_OR operator."""

    name = "logical_or"
582
583
class LogicalRightShiftOperator(Operator):
    """Select tests for the LOGICAL_RIGHT_SHIFT operator."""

    name = "logical_right_shift"
588
589
class LogicalXorOperator(Operator):
    """Select tests for the LOGICAL_XOR operator."""

    name = "logical_xor"
594
595
class MatmulOperator(Operator):
    """Select tests for the MATMUL operator."""

    name = "matmul"
    param_names = [
        "shape",
        "type",
        "accum_type",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100601
602
class MaximumOperator(Operator):
    """Select tests for the MAXIMUM operator."""

    name = "maximum"
607
608
class MaxPool2dOperator(Operator):
    """Select tests for the MAX_POOL2D operator."""

    name = "max_pool2d"
    param_names = [
        "shape",
        "type",
        "stride",
        "kernel",
        "pad",
    ]
614
615
class MinimumOperator(Operator):
    """Select tests for the MINIMUM operator."""

    name = "minimum"
620
621
class MulOperator(Operator):
    """Select tests for the MUL operator."""

    name = "mul"
    param_names = [
        "shape",
        "type",
        "perm",
        "shift",
    ]
627
628
class NegateOperator(Operator):
    """Select tests for the NEGATE operator."""

    name = "negate"
633
634
class PadOperator(Operator):
    """Select tests for the PAD operator."""

    name = "pad"
    param_names = [
        "shape",
        "type",
        "pad",
    ]
640
641
class PowOperator(Operator):
    """Select tests for the POW operator."""

    name = "pow"
646
647
class ReciprocalOperator(Operator):
    """Select tests for the RECIPROCAL operator."""

    name = "reciprocal"
652
653
class ReduceAllOperator(Operator):
    """Select tests for the REDUCE_ALL operator."""

    name = "reduce_all"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
659
660
class ReduceAnyOperator(Operator):
    """Select tests for the REDUCE_ANY operator."""

    name = "reduce_any"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
666
667
class ReduceMaxOperator(Operator):
    """Select tests for the REDUCE_MAX operator."""

    name = "reduce_max"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
673
674
class ReduceMinOperator(Operator):
    """Select tests for the REDUCE_MIN operator."""

    name = "reduce_min"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
680
681
class ReduceProductOperator(Operator):
    """Select tests for the REDUCE_PRODUCT operator."""

    name = "reduce_product"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
687
688
class ReduceSumOperator(Operator):
    """Select tests for the REDUCE_SUM operator."""

    name = "reduce_sum"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
694
695
class RescaleOperator(Operator):
    """Select tests for the RESCALE operator."""

    name = "rescale"
    # order matches the "_"-separated fields of RESCALE test names
    param_names = [
        "shape",
        "type",
        "output_type",
        "scale",
        "double_round",
        "per_channel",
    ]
708
709
class ReshapeOperator(Operator):
    """Select tests for the RESHAPE operator."""

    name = "reshape"
    param_names = [
        "shape",
        "type",
        "perm",
        "rank",
    ]
715
716
class ResizeOperator(Operator):
    """Select tests for the RESIZE operator."""

    name = "resize"
    # order matches the "_"-separated fields of RESIZE test names
    param_names = [
        "shape",
        "type",
        "mode",
        "output_type",
        "scale",
        "offset",
        "border",
    ]
730
731
class ReverseOperator(Operator):
    """Select tests for the REVERSE operator."""

    name = "reverse"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
737
738
class RFFT2DOperator(Operator):
    """Select tests for the RFFT2D operator."""

    name = "rfft2d"
743
744
class RsqrtOperator(Operator):
    """Select tests for the RSQRT operator."""

    name = "rsqrt"
749
750
class ScatterOperator(Operator):
    """Select tests for the SCATTER operator."""

    name = "scatter"
755
756
class SelectOperator(Operator):
    """Select tests for the SELECT operator."""

    name = "select"
761
762
class SigmoidOperator(Operator):
    """Select tests for the SIGMOID operator."""

    name = "sigmoid"
767
768
class SliceOperator(Operator):
    """Select tests for the SLICE operator."""

    name = "slice"
    param_names = [
        "shape",
        "type",
        "perm",
    ]
774
775
class SubOperator(Operator):
    """Select tests for the SUB operator."""

    name = "sub"
780
781
class TableOperator(Operator):
    """Select tests for the TABLE operator."""

    name = "table"
786
787
class TanhOperator(Operator):
    """Select tests for the TANH operator."""

    name = "tanh"
792
793
class TileOperator(Operator):
    """Select tests for the TILE operator."""

    name = "tile"
    param_names = [
        "shape",
        "type",
        "perm",
    ]
799
800
class TransposeOperator(Operator):
    """Select tests for the TRANSPOSE operator."""

    name = "transpose"
    param_names = [
        "shape",
        "type",
        "perm",
    ]

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Yield test paths from the directory named exactly after the operator.

        The inherited "transpose*" glob would also pick up the
        TRANSPOSE_CONV2D tests, so only an exact directory name match is used.
        """
        exact_dir = str(cls.name)
        for path in Operator._get_test_paths(test_dir, exact_dir, "*", negative):
            yield path
811
812
class TransposeConv2dOperator(Operator):
    """Select tests for the TRANSPOSE_CONV2D operator."""

    name = "transpose_conv2d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "out_shape",
    ]

    def path_params(self, path):
        """Return the test path's params with out_shape blanked out.

        Every test case has its own out_shape value, so it is replaced by an
        empty string to stop it influencing the selection.
        """
        return {**super().path_params(path), "out_shape": ""}
833
834
class WhileLoopOperator(Operator):
    """Select tests for the WHILE_LOOP operator."""

    name = "while_loop"
    param_names = [
        "shape",
        "type",
        "cond",
    ]
840
841
def parse_args():
    """Parse the command line arguments.

    Returns:
        An argparse.Namespace with the attributes test_dir, config,
        full_path, verbosity, operators and test_type.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-dir",
        default=Path.cwd(),
        type=Path,
        help=(
            "The directory where test subdirectories for all operators can be found"
            " (default: current working directory)"
        ),
    )
    parser.add_argument(
        "--config",
        default=Path(__file__).with_suffix(".json"),
        type=Path,
        help="A JSON file defining the parameters to use for each operator",
    )
    parser.add_argument(
        "--full-path", action="store_true", help="output the full path for each test"
    )
    parser.add_argument(
        "-v",
        dest="verbosity",
        action="count",
        default=0,
        help="Verbosity (can be used multiple times for more details)",
    )
    parser.add_argument(
        "operators",
        type=str,
        nargs="*",
        # the registry is fully populated by the time this runs, as every
        # Operator subclass registers itself when its class is defined
        help=(
            "Select tests for the specified operator(s)"
            " - all operators are assumed if none are specified"
            f" - choose from: {list(Operator.registry)}"
        ),
    )
    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative"],
        default="positive",
        type=str,
        help="type of tests selected, positive or negative",
    )
    return parser.parse_args()
889
890
def main():
    """Select tests for the requested operators and print their names.

    Returns:
        0 on success, 2 if the config file cannot be read or is not a
        JSON object
    """
    args = parse_args()

    loglevels = (logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG)
    logger.setLevel(loglevels[min(args.verbosity, len(loglevels) - 1)])
    logger.info(f"{__file__}: args: {args}")

    try:
        # JSON text is UTF-8, independent of the platform locale
        with open(args.config, "r", encoding="utf-8") as fd:
            config = json.load(fd)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError: invalid JSON
        # (json.JSONDecodeError) or undecodable text (UnicodeDecodeError)
        logger.error(f"Config file error: {e}")
        return 2
    if not isinstance(config, dict):
        logger.error("Config file error: expected a JSON object at the top level")
        return 2

    # operators not in the registry would otherwise be silently ignored
    unknown = [op for op in args.operators if op not in Operator.registry]
    if unknown:
        logger.warning(f"Ignoring unknown operator(s): {unknown}")

    negative = args.test_type == "negative"
    for op_name, op_class in Operator.registry.items():
        if args.operators and op_name not in args.operators:
            continue
        op_params = config.get(op_name, {})
        op = op_class(args.test_dir, op_params, negative)
        for test_path in op.select_tests():
            print(test_path.resolve() if args.full_path else test_path.name)

    return 0
915
916
if __name__ == "__main__":
    # sys.exit rather than the site-module exit() builtin, which is not
    # available when Python runs with -S
    import sys

    sys.exit(main())