blob: faefc8573b490ed351d2d4427e76538c30c0bce4 [file] [log] [blame]
Jeremy Johnson35396f22023-01-04 17:05:25 +00001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01002# SPDX-License-Identifier: Apache-2.0
3"""Select generated tests."""
4import argparse
5import itertools
6import json
7import logging
James Ward736fd1a2023-01-23 17:13:37 +00008import re
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01009from pathlib import Path
10from typing import Any
11from typing import Dict
12from typing import List
13
# Install a default root handler at import time, then fetch the named logger
# used throughout this script (main() later sets its level from -v flags).
logging.basicConfig()
logger = logging.getLogger(name="test_select")
16
17
def expand_params(permutes: Dict[str, List[Any]], others: Dict[str, List[Any]]):
    """Generate permuted combinations of a dictionary of values and combine with others.

    permutes: a dictionary with sequences of values to be fully permuted
    others: a dictionary with sequences of values not fully permuted, but all used

    This yields dictionaries with one value from each of the items in permutes,
    combined with one value from each of the items in others. The number of
    dictionaries generated is the larger of the number of permutations and the
    longest sequence in others; shorter sequences are cycled. If any item has
    no values at all, no dictionaries are generated.

    Example 1:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 5, "d": False},
        ]

    Example 2:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7, 8, 9], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 8, "d": False},
            {"a": 1, "b": 3, "c": 9, "d": True},
        ]

    Raises:
        ValueError if any item is in both permutes and others
    """
    for k in permutes:
        if k in others:
            raise ValueError(f"item conflict: {k}")

    # Every generated dictionary needs one value per key, so nothing can be
    # generated when a key has no values (cycling an empty sequence would
    # otherwise end this generator with a RuntimeError)
    if any(len(v) == 0 for v in itertools.chain(permutes.values(), others.values())):
        return

    p_keys = list(permutes)
    # Number of distinct permutations. It is 1 (not 0) when permutes is empty,
    # so a single, empty dictionary is returned when others is also empty
    p_product_len = 1
    for values in permutes.values():
        p_product_len *= len(values)
    # cyclic generator over the product of all the permuted values; each item
    # is a tuple ordered like p_keys
    p_generator = itertools.cycle(itertools.product(*permutes.values()))

    # a separate cyclic generator for each of the others value sequences
    o_keys = list(others)
    o_generators = [itertools.cycle(values) for values in others.values()]

    # The number of params dictionaries generated will be the maximum size
    # of the permuted values and the non-permuted values from others
    max_items = max([p_product_len] + [len(values) for values in others.values()])

    for _ in range(max_items):
        params = dict(zip(p_keys, next(p_generator)))
        for key, generator in zip(o_keys, o_generators):
            params[key] = next(generator)
        yield params
106
107
class Operator:
    """Base class for operator specific selection properties.

    Subclasses set ``name`` (and ``param_names`` when their test names carry
    extra parameters); defining a subclass registers it in
    ``Operator.registry`` under that name.
    """

    # A registry of all Operator subclasses, indexed by the operator name
    registry = {}

    def __init_subclass__(cls, **kwargs):
        """Subclass initialiser to register all Operator classes."""
        super().__init_subclass__(**kwargs)
        cls.registry[cls.name] = cls

    # Derived classes must override the operator name
    name = None
    # Operators with additional parameters must override the param_names
    # NB: the order must match the order the values appear in the test names
    param_names = ["shape", "type"]

    # Working set of param_names - updated for negative tests
    wks_param_names = None

    # Name suffixes of the sets in a compliance test series. Paths with one of
    # these suffixes are reported once, by their base name, and expanded back
    # to every set when selected.
    COMPLIANCE_SETS = ("_s0", "_s1", "_s2", "_s3", "_s4", "_s5")

    def __init__(
        self,
        test_dir: Path,
        config: Dict[str, Any],
        negative=False,
        ignore_missing=False,
    ):
        """Initialise the selection parameters for an operator.

        test_dir: the directory where the tests for all operators can
            be found
        config: a dictionary with (all keys optional):
            "params" - a dictionary with mappings of parameter
                names to the values to select (a sub-set of
                expected values for instance)
            "permutes" - a list of parameter names to be permuted
            "preselected" - a list of dictionaries containing
                parameter names and pre-chosen values
            "sparsity" - a dictionary of parameter names with a
                sparsity value
            "full_sparsity" - "true"/"false" (a JSON boolean is also
                accepted) to use the sparsity value on
                permutes/params/preselected
            "exclude_patterns" - a list of regexes whereby each
                match will not be considered for selection.
                Exclusion happens BEFORE test selection (i.e.
                before permutes are applied).
            "errorifs" - list of ERRORIF case names to be selected
                after exclusion (negative tests)
            The config dictionary itself is not modified.
        negative: bool indicating if negative testing is being selected
            which filters for ERRORIF in the test name and only selects
            the first test found (ERRORIF tests)
        ignore_missing: bool indicating if missing tests should be ignored

        EXAMPLE CONFIG (with non-json comments):
            "params": {
                "output_type": [
                    "outi8",
                    "outb"
                ]
            },
            "permutes": [
                "shape",
                "type"
            ],
            "sparsity": {
                "pad": 15
            },
            "preselected": [
                {
                    "shape": "6",
                    "type": "i8",
                    "pad": "pad00"
                }
            ],
            "exclude_patterns": [
                # Exclude positive (not ERRORIF) integer tests
                "^((?!ERRORIF).)*_(i8|i16|i32|b)_out(i8|i16|i32|b)",
                # Exclude negative (ERRORIF) i8 test
                ".*_ERRORIF_.*_i8_outi8"
            ],
            "errorifs": [
                "InputZeroPointNotZero"
            ]
        """
        assert isinstance(
            self.name, str
        ), f"{self.__class__.__name__}: {self.name} is not a valid operator name"

        self.negative = negative
        self.ignore_missing = ignore_missing
        self.wks_param_names = self.param_names.copy()
        # Work on a shallow copy so the caller's config is never modified
        config = dict(config)
        if self.negative:
            # need to override positive set up - use "errorifs" config if set
            # add in errorif case before shape to support all ops, including
            # different ops like COND_IF and CONVnD etc
            index = self.wks_param_names.index("shape")
            self.wks_param_names[index:index] = ["ERRORIF", "case"]
            config["params"] = {x: [] for x in self.wks_param_names}
            config["params"]["case"] = list(config.get("errorifs", []))
            config["permutes"] = []
            config["preselected"] = []

        # Copy params as well: missing values are filled in with defaults below
        self.params = dict(config.get("params", {}))
        self.permutes = config.get("permutes", [])
        self.sparsity = config.get("sparsity", {})
        full_sparsity = config.get("full_sparsity", False)
        self.full_sparsity = full_sparsity is True or full_sparsity == "true"
        self.preselected = config.get("preselected", [])
        self.exclude_patterns = config.get("exclude_patterns", [])
        self.non_permutes = [x for x in self.wks_param_names if x not in self.permutes]
        logger.info(f"{self.name}: permutes={self.permutes}")
        logger.info(f"{self.name}: non_permutes={self.non_permutes}")
        logger.info(f"{self.name}: exclude_patterns={self.exclude_patterns}")

        # Compile the exclusion regexes once rather than once per test path
        exclude_regexes = [re.compile(pattern) for pattern in self.exclude_patterns]
        self.test_paths = []
        excluded_paths = []
        for path in self.get_test_paths(test_dir, self.negative):
            if any(regex.fullmatch(path.name) for regex in exclude_regexes):
                excluded_paths.append(path)
            else:
                self.test_paths.append(path)

        logger.debug(f"{self.name}: regex excluded paths={excluded_paths}")

        if not self.test_paths:
            logger.error(f"no tests found for {self.name} in {test_dir}")
        logger.debug(f"{self.name}: paths={self.test_paths}")

        # get default parameter values for any not given in the config
        default_params = self.get_default_params()
        for param, values in default_params.items():
            if not self.params.get(param):
                self.params[param] = values
        for param in self.wks_param_names:
            logger.info(f"{self.name}: params[{param}]={self.params[param]}")

    @staticmethod
    def _get_test_paths(test_dir: Path, base_dir_glob, path_glob, negative):
        """Generate test paths for operators using operator specifics.

        Directories matching base_dir_glob under test_dir are searched (in
        sorted order) for entries matching path_glob. Only ERRORIF paths are
        kept when negative is true, and only non-ERRORIF paths otherwise.
        A compliance test series is yielded once, without its set suffix.
        """
        for base_dir in sorted(test_dir.glob(base_dir_glob)):
            for path in sorted(base_dir.glob(path_glob)):
                if ("ERRORIF" in str(path)) != bool(negative):
                    continue
                # Check for compliance test set paths
                suffix = path.name[-3:]
                if suffix in Operator.COMPLIANCE_SETS:
                    if suffix != Operator.COMPLIANCE_SETS[0]:
                        # Only return one of the test sets
                        continue
                    yield path.with_name(path.name[:-3])
                else:
                    yield path

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths for this operator."""
        yield from Operator._get_test_paths(test_dir, f"{cls.name}*", "*", negative)

    @staticmethod
    def _expand_compliance_sets(path: Path):
        """Yield path if it exists, else every compliance set path derived from it."""
        if path.exists():
            yield path
        else:
            # Compliance test series - expand to all sets
            for suffix in Operator.COMPLIANCE_SETS:
                yield path.with_name(f"{path.name}{suffix}")

    def path_params(self, path):
        """Return a dictionary of params from the test path.

        The operator name's "_"-separated parts are dropped from the start of
        the path name; the remaining parts map, in order, onto
        wks_param_names.
        """
        op_name_parts = self.name.split("_")
        values = path.name.split("_")[len(op_name_parts) :]
        assert len(values) == len(
            self.wks_param_names
        ), f"len({values}) == len({self.wks_param_names})"
        return dict(zip(self.wks_param_names, values))

    def get_default_params(self):
        """Get the default parameter values from the test names.

        Returns a dictionary mapping each working parameter name to the
        sorted, de-duplicated values found across self.test_paths.
        """
        params = {param: set() for param in self.wks_param_names}
        for path in self.test_paths:
            path_params = self.path_params(path)
            for param, values in params.items():
                values.add(path_params[param])
        return {param: sorted(values) for param, values in params.items()}

    def _skipped_by_sparsity(self, path, path_params, n):
        """Return True if test number n is thinned out by any param's sparsity."""
        for k in path_params:
            if k in self.sparsity and n % self.sparsity[k] != 0:
                logger.debug(f"Skipping due to {k} sparsity - {path.name}")
                return True
        return False

    def _report_missing(self, unused_preselected, unused_permuted, unused_values):
        """Log the preselected/permuted combinations and values never selected."""
        for params in unused_preselected:
            logger.warning(f"MISSING preselected: {params}")
        for params in unused_permuted:
            logger.debug(f"MISSING permutation: {params}")
        for k, values in unused_values.items():
            if not values:
                continue
            if k not in self.sparsity:
                logger.warning(f"MISSING {len(values)} values for {k}: {values}")
            else:
                logger.info(
                    f"Skipped {len(values)} values for {k} due to sparsity setting"
                )
                logger.debug(f"Values skipped: {values}")

    def select_tests(self):
        """Generate the paths to the selected tests for this operator.

        First, tests matching a permuted or preselected parameter combination
        are selected. Then, among the remaining tests, one is selected for
        each parameter value not yet covered (subject to sparsity). Compliance
        test series are expanded to all their sets.
        """
        if not self.test_paths:
            # Exit early when nothing to select from
            return

        # the test paths that have not been selected yet
        unused_paths = set(self.test_paths)

        # a list of dictionaries of unused preselected parameter combinations
        unused_preselected = list(self.preselected)
        logger.debug(f"preselected: {unused_preselected}")

        # a list of dictionaries of unused permuted parameter combinations
        permutes = {k: self.params[k] for k in self.permutes}
        others = {k: self.params[k] for k in self.non_permutes}
        unused_permuted = list(expand_params(permutes, others))
        logger.debug(f"permuted: {unused_permuted}")

        # a dictionary of sets of unused parameter values
        if self.negative:
            # We only care about selecting a test for each errorif case
            unused_values = {k: set() for k in self.params}
            unused_values["case"] = set(self.params["case"])
        else:
            unused_values = {k: set(v) for k, v in self.params.items()}

        # select tests matching permuted, or preselected, parameter combinations
        for n, path in enumerate(self.test_paths):
            path_params = self.path_params(path)
            in_permuted = path_params in unused_permuted
            in_preselected = path_params in unused_preselected
            if not (in_permuted or in_preselected):
                continue
            unused_paths.remove(path)
            if in_preselected:
                unused_preselected.remove(path_params)
            if in_permuted:
                unused_permuted.remove(path_params)
            if self.negative:
                # remove any other errorif cases, so we only match one
                unused_permuted = [
                    p for p in unused_permuted if p["case"] != path_params["case"]
                ]
            if self.full_sparsity and self._skipped_by_sparsity(path, path_params, n):
                continue
            # remove the param values used by this path
            for k, value in path_params.items():
                unused_values[k].discard(value)
            logger.debug(f"FOUND wanted: {path.name}")
            yield from self._expand_compliance_sets(path)

        # search for tests that match any unused parameter values
        for n, path in enumerate(sorted(unused_paths)):
            path_params = self.path_params(path)
            # select paths with unused param values
            # skipping some, if sparsity is set for the param
            for k in path_params:
                if path_params[k] in unused_values[k] and (
                    k not in self.sparsity or n % self.sparsity[k] == 0
                ):
                    # remove the param values used by this path
                    for p, value in path_params.items():
                        unused_values[p].discard(value)
                    sparsity = self.sparsity[k] if k in self.sparsity else 0
                    logger.debug(f"FOUND unused [{k}/{n}/{sparsity}]: {path.name}")
                    yield from self._expand_compliance_sets(path)
                    break

        if not self.ignore_missing:
            self._report_missing(unused_preselected, unused_permuted, unused_values)
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100404
405
class AbsOperator(Operator):
    """Selects generated tests for the ABS operator (default shape/type params)."""

    name = "abs"
410
411
class ArithmeticRightShiftOperator(Operator):
    """Selects generated tests for the ARITHMETIC_RIGHT_SHIFT operator."""

    name = "arithmetic_right_shift"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "rounding",
    ]
417
418
class AddOperator(Operator):
    """Selects generated tests for the ADD operator (default shape/type params)."""

    name = "add"
423
424
class ArgmaxOperator(Operator):
    """Selects generated tests for the ARGMAX operator."""

    name = "argmax"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "axis",
    ]
430
431
class AvgPool2dOperator(Operator):
    """Selects generated tests for the AVG_POOL2D operator."""

    name = "avg_pool2d"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "accum_type",
        "stride",
        "kernel",
        "pad",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100437
438
class BitwiseAndOperator(Operator):
    """Selects generated tests for the BITWISE_AND operator (default params)."""

    name = "bitwise_and"
443
444
class BitwiseNotOperator(Operator):
    """Selects generated tests for the BITWISE_NOT operator (default params)."""

    name = "bitwise_not"
449
450
class BitwiseOrOperator(Operator):
    """Selects generated tests for the BITWISE_OR operator (default params)."""

    name = "bitwise_or"
455
456
class BitwiseXorOperator(Operator):
    """Selects generated tests for the BITWISE_XOR operator (default params)."""

    name = "bitwise_xor"
461
462
class CastOperator(Operator):
    """Selects generated tests for the CAST operator."""

    name = "cast"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "output_type",
    ]
468
469
class CeilOperator(Operator):
    """Selects generated tests for the CEIL operator (default params)."""

    name = "ceil"
474
475
class ClampOperator(Operator):
    """Selects generated tests for the CLAMP operator (default params)."""

    name = "clamp"
480
481
class CLZOperator(Operator):
    """Selects generated tests for the CLZ operator (default params)."""

    name = "clz"
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100486
487
class ConcatOperator(Operator):
    """Selects generated tests for the CONCAT operator."""

    name = "concat"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "axis",
    ]
493
494
class CondIfOperator(Operator):
    """Selects generated tests for the COND_IF operator."""

    name = "cond_if"
    # values appear in the test names in this order (variant comes first)
    param_names = [
        "variant",
        "shape",
        "type",
        "cond",
    ]
500
501
class ConstOperator(Operator):
    """Selects generated tests for the CONST operator (default params)."""

    name = "const"
506
507
class Conv2dOperator(Operator):
    """Selects generated tests for the CONV2D operator."""

    name = "conv2d"
    # values appear in the test names in this order (kernel comes first)
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100513
514
class Conv3dOperator(Operator):
    """Selects generated tests for the CONV3D operator."""

    name = "conv3d"
    # values appear in the test names in this order (kernel comes first)
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100520
521
class DepthwiseConv2dOperator(Operator):
    """Selects generated tests for the DEPTHWISE_CONV2D operator."""

    name = "depthwise_conv2d"
    # values appear in the test names in this order (kernel comes first)
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100527
528
class DimOperator(Operator):
    """Test selector for the DIM operator."""

    name = "dim"
    param_names = ["shape", "type", "axis"]


# Backwards-compatible alias for the original, misspelt class name
DimOeprator = DimOperator
534
535
class EqualOperator(Operator):
    """Selects generated tests for the EQUAL operator (default params)."""

    name = "equal"
540
541
class ExpOperator(Operator):
    """Selects generated tests for the EXP operator (default params)."""

    name = "exp"
546
547
class ErfOperator(Operator):
    """Selects generated tests for the ERF operator (default params)."""

    name = "erf"
552
553
class FFT2DOperator(Operator):
    """Selects generated tests for the FFT2D operator."""

    name = "fft2d"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "inverse",
    ]
559
560
class FloorOperator(Operator):
    """Selects generated tests for the FLOOR operator (default params)."""

    name = "floor"
565
566
class FullyConnectedOperator(Operator):
    """Selects generated tests for the FULLY_CONNECTED operator."""

    name = "fully_connected"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "accum_type",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100572
573
class GatherOperator(Operator):
    """Selects generated tests for the GATHER operator (default params)."""

    name = "gather"
578
579
class GreaterOperator(Operator):
    """Selects generated tests for the GREATER operator (default params)."""

    name = "greater"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Yield test paths from the base directory named exactly "greater".

        Unlike the default "<name>*" base directory glob, an exact name does
        not also match the directories of the GREATER_EQUAL operator.
        """
        base_dir_glob = cls.name
        yield from Operator._get_test_paths(test_dir, base_dir_glob, "*", negative)
589
590
class GreaterEqualOperator(Operator):
    """Selects generated tests for the GREATER_EQUAL operator (default params)."""

    name = "greater_equal"
595
596
class IdentityOperator(Operator):
    """Selects generated tests for the IDENTITY operator (default params)."""

    name = "identity"
601
602
class IntDivOperator(Operator):
    """Selects generated tests for the INTDIV operator (default params)."""

    name = "intdiv"
607
608
class LogOperator(Operator):
    """Test selector for the LOG operator."""

    name = "log"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths for this operator.

        The default "log*" base directory glob would also match the
        logical_* operator directories (whose test names do not parse with
        LOG's parameter names), so only the directory named exactly "log" is
        searched.
        """
        yield from Operator._get_test_paths(test_dir, f"{cls.name}", "*", negative)
613
614
class LogicalAndOperator(Operator):
    """Selects generated tests for the LOGICAL_AND operator (default params)."""

    name = "logical_and"
619
620
class LogicalLeftShiftOperator(Operator):
    """Selects generated tests for the LOGICAL_LEFT_SHIFT operator."""

    name = "logical_left_shift"
625
626
class LogicalNotOperator(Operator):
    """Selects generated tests for the LOGICAL_NOT operator (default params)."""

    name = "logical_not"
631
632
class LogicalOrOperator(Operator):
    """Selects generated tests for the LOGICAL_OR operator (default params)."""

    name = "logical_or"
637
638
class LogicalRightShiftOperator(Operator):
    """Selects generated tests for the LOGICAL_RIGHT_SHIFT operator."""

    name = "logical_right_shift"
643
644
class LogicalXorOperator(Operator):
    """Selects generated tests for the LOGICAL_XOR operator (default params)."""

    name = "logical_xor"
649
650
class MatmulOperator(Operator):
    """Selects generated tests for the MATMUL operator."""

    name = "matmul"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "accum_type",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100656
657
class MaximumOperator(Operator):
    """Selects generated tests for the MAXIMUM operator (default params)."""

    name = "maximum"
662
663
class MaxPool2dOperator(Operator):
    """Selects generated tests for the MAX_POOL2D operator."""

    name = "max_pool2d"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "stride",
        "kernel",
        "pad",
    ]
669
670
class MinimumOperator(Operator):
    """Selects generated tests for the MINIMUM operator (default params)."""

    name = "minimum"
675
676
class MulOperator(Operator):
    """Selects generated tests for the MUL operator."""

    name = "mul"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "perm",
        "shift",
    ]
682
683
class NegateOperator(Operator):
    """Selects generated tests for the NEGATE operator (default params)."""

    name = "negate"
688
689
class PadOperator(Operator):
    """Selects generated tests for the PAD operator."""

    name = "pad"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "pad",
    ]
695
696
class PowOperator(Operator):
    """Selects generated tests for the POW operator (default params)."""

    name = "pow"
701
702
class ReciprocalOperator(Operator):
    """Selects generated tests for the RECIPROCAL operator (default params)."""

    name = "reciprocal"
707
708
class ReduceAllOperator(Operator):
    """Selects generated tests for the REDUCE_ALL operator."""

    name = "reduce_all"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "axis",
    ]
714
715
class ReduceAnyOperator(Operator):
    """Selects generated tests for the REDUCE_ANY operator."""

    name = "reduce_any"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "axis",
    ]
721
722
class ReduceMaxOperator(Operator):
    """Selects generated tests for the REDUCE_MAX operator."""

    name = "reduce_max"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "axis",
    ]
728
729
class ReduceMinOperator(Operator):
    """Selects generated tests for the REDUCE_MIN operator."""

    name = "reduce_min"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "axis",
    ]
735
736
class ReduceProductOperator(Operator):
    """Selects generated tests for the REDUCE_PRODUCT operator."""

    name = "reduce_product"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "axis",
    ]
742
743
class ReduceSumOperator(Operator):
    """Selects generated tests for the REDUCE_SUM operator."""

    name = "reduce_sum"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "axis",
    ]
749
750
class RescaleOperator(Operator):
    """Selects generated tests for the RESCALE operator."""

    name = "rescale"
    # values appear in the test names in this order
    param_names = ["shape", "type", "output_type", "scale", "double_round", "per_channel"]
763
764
class ReshapeOperator(Operator):
    """Selects generated tests for the RESHAPE operator."""

    name = "reshape"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "perm",
        "rank",
        "out",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100770
771
class ResizeOperator(Operator):
    """Selects generated tests for the RESIZE operator."""

    name = "resize"
    # values appear in the test names in this order
    param_names = ["shape", "type", "mode", "output_type", "scale", "offset", "border"]
785
786
class ReverseOperator(Operator):
    """Selects generated tests for the REVERSE operator."""

    name = "reverse"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "axis",
    ]
792
793
class RFFT2DOperator(Operator):
    """Selects generated tests for the RFFT2D operator (default params)."""

    name = "rfft2d"
798
799
class RsqrtOperator(Operator):
    """Selects generated tests for the RSQRT operator (default params)."""

    name = "rsqrt"
804
805
class ScatterOperator(Operator):
    """Selects generated tests for the SCATTER operator (default params)."""

    name = "scatter"
810
811
class SelectOperator(Operator):
    """Selects generated tests for the SELECT operator (default params)."""

    name = "select"
816
817
class SigmoidOperator(Operator):
    """Selects generated tests for the SIGMOID operator (default params)."""

    name = "sigmoid"
822
823
class SliceOperator(Operator):
    """Selects generated tests for the SLICE operator."""

    name = "slice"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "perm",
    ]
829
830
class SubOperator(Operator):
    """Selects generated tests for the SUB operator (default params)."""

    name = "sub"
835
836
class TableOperator(Operator):
    """Selects generated tests for the TABLE operator (default params)."""

    name = "table"
841
842
class TanhOperator(Operator):
    """Selects generated tests for the TANH operator (default params)."""

    name = "tanh"
847
848
class TileOperator(Operator):
    """Selects generated tests for the TILE operator."""

    name = "tile"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "perm",
    ]
854
855
class TransposeOperator(Operator):
    """Selects generated tests for the TRANSPOSE operator."""

    name = "transpose"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "perm",
    ]

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Yield test paths from the base directory named exactly "transpose".

        Unlike the default "<name>*" base directory glob, an exact name does
        not also match the directories of the TRANSPOSE_CONV2D operator.
        """
        base_dir_glob = cls.name
        yield from Operator._get_test_paths(test_dir, base_dir_glob, "*", negative)
866
867
class TransposeConv2dOperator(Operator):
    """Selects generated tests for the TRANSPOSE_CONV2D operator."""

    name = "transpose_conv2d"
    # values appear in the test names in this order (kernel comes first)
    param_names = ["kernel", "shape", "type", "accum_type", "stride", "pad", "out_shape"]

    def path_params(self, path):
        """Parse the test name like the base class, but blank out out_shape.

        Each test case has its own output shape, so keeping it would stop
        that parameter from ever being useful for selection.
        """
        parsed = super().path_params(path)
        parsed.update(out_shape="")
        return parsed
888
889
class WhileLoopOperator(Operator):
    """Selects generated tests for the WHILE_LOOP operator."""

    name = "while_loop"
    # values appear in the test names in this order
    param_names = [
        "shape",
        "type",
        "cond",
    ]
895
896
def parse_args():
    """Parse the command line arguments.

    Returns:
        an argparse.Namespace with test_dir, config, selector, full_path,
        verbosity, operators and test_type attributes
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-dir",
        default=Path.cwd(),
        type=Path,
        help=(
            "The directory where test subdirectories for all operators can be found"
            " (default: current working directory)"
        ),
    )
    parser.add_argument(
        "--config",
        default=Path(__file__).with_suffix(".json"),
        type=Path,
        help="A JSON file defining the parameters to use for each operator",
    )
    parser.add_argument(
        "--selector",
        default="default",
        type=str,
        help="The selector in the selection dictionary to use for each operator",
    )
    parser.add_argument(
        "--full-path", action="store_true", help="output the full path for each test"
    )
    parser.add_argument(
        "-v",
        dest="verbosity",
        action="count",
        default=0,
        help="Verbosity (can be used multiple times for more details)",
    )
    parser.add_argument(
        "operators",
        type=str,
        nargs="*",
        help=(
            "Select tests for the specified operator(s)"
            " - all operators are assumed if none are specified"
            f" - choose from: {list(Operator.registry)}"
        ),
    )
    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative"],
        default="positive",
        type=str,
        help="type of tests selected, positive or negative",
    )
    return parser.parse_args()
950
951
def main():
    """Select tests for the requested operators and print their paths.

    Returns:
        0 on success, 2 if the config file could not be read or parsed
    """
    args = parse_args()

    loglevels = (logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG)
    logger.setLevel(loglevels[min(args.verbosity, len(loglevels) - 1)])
    logger.info(f"{__file__}: args: {args}")

    try:
        with open(args.config, "r", encoding="utf-8") as fd:
            config = json.load(fd)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError also covers
        # json.JSONDecodeError and UnicodeDecodeError
        logger.error(f"Config file error: {e}")
        return 2

    # Operator names that are not registered would otherwise be silently ignored
    unknown_ops = [op for op in args.operators if op not in Operator.registry]
    if unknown_ops:
        logger.warning(f"Unknown operator(s) ignored: {unknown_ops}")

    negative = args.test_type == "negative"
    for op_name, op_class in Operator.registry.items():
        if args.operators and op_name not in args.operators:
            continue
        op_params = config[op_name] if op_name in config else {}
        if "selection" in op_params and args.selector in op_params["selection"]:
            selection_config = op_params["selection"][args.selector]
        else:
            logger.warning(
                f"Could not find selection config {args.selector} for {op_name}"
            )
            selection_config = {}
        op = op_class(args.test_dir, selection_config, negative)
        for test_path in op.select_tests():
            print(test_path.resolve() if args.full_path else test_path.name)

    return 0
983
984
if __name__ == "__main__":
    # SystemExit is used instead of the site-provided exit() builtin, which
    # is not guaranteed to exist (e.g. under "python -S")
    raise SystemExit(main())