blob: b7bbfc3b710cd612918871a958b701cebc645e4d [file] [log] [blame]
Jeremy Johnson35396f22023-01-04 17:05:25 +00001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01002# SPDX-License-Identifier: Apache-2.0
3"""Select generated tests."""
4import argparse
5import itertools
6import json
7import logging
James Ward736fd1a2023-01-23 17:13:37 +00008import re
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01009from pathlib import Path
10from typing import Any
11from typing import Dict
12from typing import List
13
# Give the root logger a default handler (stderr) if it has none, then
# create the named logger used throughout this script
logging.basicConfig()
logger = logging.getLogger(name="test_select")
16
17
def expand_params(permutes: Dict[str, List[Any]], others: Dict[str, List[Any]]):
    """Generate permuted combinations of a dictionary of values and combine with others.

    permutes: a dictionary with sequences of values to be fully permuted
    others: a dictionary with sequences of values not fully permuted, but all used

    This yields dictionaries with one value from each of the items in permutes,
    combined with one value from each of the items in others.

    The number of dictionaries generated is the larger of the number of
    permutations of permutes and the length of the longest sequence in
    others; anything shorter is cycled to fill the gap. If any item has no
    values at all, no complete combination exists and nothing is generated.

    Example 1:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 5, "d": False},
        ]

    Example 2:

        permutes = {"a": [1, 2], "b": [3, 4]}
        others = {"c": [5, 6, 7, 8, 9], "d": [True, False]}

    generates:

        [
            {"a": 1, "b": 3, "c": 5, "d": True},
            {"a": 1, "b": 4, "c": 6, "d": False},
            {"a": 2, "b": 3, "c": 7, "d": True},
            {"a": 2, "b": 4, "c": 8, "d": False},
            {"a": 1, "b": 3, "c": 9, "d": True},
        ]

    Raises:
        ValueError if any item is in both permutes and others
    """
    for k in permutes:
        if k in others:
            raise ValueError(f"item conflict: {k}")

    # Every combination needs one value from each item, so an item without
    # values means there are no combinations. Without this check, advancing
    # an exhausted cycle raises StopIteration, which becomes a RuntimeError
    # inside this generator.
    all_values = itertools.chain(permutes.values(), others.values())
    if any(len(values) == 0 for values in all_values):
        return

    p_keys = list(permutes)
    # product() of no sequences yields a single empty tuple, so an empty
    # permutes still contributes one (empty) combination; this lets two empty
    # dictionaries generate a single, empty dictionary
    p_product_len = 1
    for values in permutes.values():
        p_product_len *= len(values)
    # cyclic generator for the product of all the permuted values, each tuple
    # ordered like p_keys
    p_generator = itertools.cycle(itertools.product(*permutes.values()))

    # a separate cyclic generator for each list of the others values
    o_keys = list(others)
    o_generators = [itertools.cycle(values) for values in others.values()]

    # The number of params dictionaries generated will be the maximum size
    # of the permuted values and the non-permuted values from others
    max_items = max([p_product_len] + [len(values) for values in others.values()])

    for _ in range(max_items):
        params = dict(zip(p_keys, next(p_generator)))
        for key, generator in zip(o_keys, o_generators):
            params[key] = next(generator)
        yield params
106
107
class Operator:
    """Base class for operator specific selection properties.

    Each subclass registers itself in ``Operator.registry`` under its
    ``name`` and lists in ``param_names`` the underscore separated values
    that follow the operator name in each of its test names.
    """

    # A registry of all Operator subclasses, indexed by the operator name
    registry = {}

    def __init_subclass__(cls, **kwargs):
        """Subclass initialiser to register all Operator classes."""
        super().__init_subclass__(**kwargs)
        cls.registry[cls.name] = cls

    # Derived classes must override the operator name
    name = None
    # Operators with additional parameters must override the param_names
    # NB: the order must match the order the values appear in the test names
    param_names = ["shape", "type"]

    # Working set of param_names - updated for negative tests
    wks_param_names = None

    def __init__(
        self,
        test_dir: Path,
        config: Dict[str, Dict[str, List[Any]]],
        negative=False,
        ignore_missing=False,
    ):
        """Initialise the selection parameters for an operator.

        test_dir: the directory where the tests for all operators can
            be found
        config: a dictionary with:
            "params" - a dictionary with mappings of parameter
                names to the values to select (a sub-set of
                expected values for instance)
            "permutes" - a list of parameter names to be permuted
            "preselected" - a list of dictionaries containing
                parameter names and pre-chosen values
            "sparsity" - a dictionary of parameter names with a
                sparsity value
            "full_sparsity" - "true"/"false" (a JSON boolean is also
                accepted) to use the sparsity value on
                permutes/params/preselected
            "exclude_patterns" - a list of regular expressions; any test
                name fully matching one of them is not considered for
                selection. Exclusion happens BEFORE test selection (i.e.
                before permutes are applied).
            "errorifs" - list of ERRORIF case names to be selected
                after exclusion (negative tests)
            The config dictionary passed in is not modified.
        negative: bool indicating if negative testing is being selected
            which filters for ERRORIF in the test name and only selects
            the first test found (ERRORIF tests)
        ignore_missing: bool indicating if missing tests should be ignored

        EXAMPLE CONFIG (with non-json comments):
            "params": {
                "output_type": [
                    "outi8",
                    "outb"
                ]
            },
            "permutes": [
                "shape",
                "type"
            ],
            "sparsity": {
                "pad": 15
            },
            "preselected": [
                {
                    "shape": "6",
                    "type": "i8",
                    "pad": "pad00"
                }
            ],
            "exclude_patterns": [
                # Exclude positive (not ERRORIF) integer tests
                "^((?!ERRORIF).)*_(i8|i16|i32|b)_out(i8|i16|i32|b)",
                # Exclude negative (ERRORIF) i8 test
                ".*_ERRORIF_.*_i8_outi8"
            ],
            "errorifs": [
                "InputZeroPointNotZero"
            ]
        """
        assert isinstance(
            self.name, str
        ), f"{self.__class__.__name__}: {self.name} is not a valid operator name"

        self.negative = negative
        self.ignore_missing = ignore_missing
        self.wks_param_names = self.param_names.copy()
        # Work on a shallow copy so the negative overrides below never leak
        # back into the caller's config (e.g. when the same config is reused
        # for a later positive run)
        config = dict(config)
        if self.negative:
            # need to override positive set up - use "errorifs" config if set
            # add in errorif case before shape to support all ops, including
            # different ops like COND_IF and CONVnD etc
            index = self.wks_param_names.index("shape")
            self.wks_param_names[index:index] = ["ERRORIF", "case"]
            config["params"] = {x: [] for x in self.wks_param_names}
            config["params"]["case"] = config.get("errorifs", [])
            config["permutes"] = []
            config["preselected"] = []

        # params is copied because missing values are filled in with the
        # defaults found from the test names further below
        self.params = dict(config.get("params", {}))
        self.permutes = config.get("permutes", [])
        self.sparsity = config.get("sparsity", {})
        full_sparsity = config.get("full_sparsity", False)
        self.full_sparsity = full_sparsity is True or full_sparsity == "true"
        self.preselected = config.get("preselected", [])
        self.exclude_patterns = config.get("exclude_patterns", [])
        self.non_permutes = [x for x in self.wks_param_names if x not in self.permutes]
        logger.info(f"{self.name}: permutes={self.permutes}")
        logger.info(f"{self.name}: non_permutes={self.non_permutes}")
        logger.info(f"{self.name}: exclude_patterns={self.exclude_patterns}")

        self.test_paths = []
        excluded_paths = []
        for path in self.get_test_paths(test_dir, self.negative):
            if any(re.fullmatch(pattern, path.name) for pattern in self.exclude_patterns):
                excluded_paths.append(path)
            else:
                self.test_paths.append(path)

        logger.debug(f"{self.name}: regex excluded paths={excluded_paths}")

        if not self.test_paths:
            logger.error(f"no tests found for {self.name} in {test_dir}")
        logger.debug(f"{self.name}: paths={self.test_paths}")

        # get default parameter values for any not given (or empty) in config
        for param, values in self.get_default_params().items():
            if not self.params.get(param):
                self.params[param] = values
        for param in self.wks_param_names:
            logger.info(f"{self.name}: params[{param}]={self.params[param]}")

    @staticmethod
    def _get_test_paths(test_dir: Path, base_dir_glob, path_glob, negative):
        """Generate test paths for operators using operator specifics.

        Yields, in sorted order, the entries matching path_glob inside each
        directory of test_dir matching base_dir_glob. Paths containing
        "ERRORIF" are yielded only when negative is truthy; all other paths
        only when it is falsy.
        """
        for base_dir in sorted(test_dir.glob(base_dir_glob)):
            for path in sorted(base_dir.glob(path_glob)):
                if bool(negative) == ("ERRORIF" in str(path)):
                    yield path

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths for this operator."""
        yield from Operator._get_test_paths(test_dir, f"{cls.name}*", "*", negative)

    def path_params(self, path):
        """Return a dictionary of params from the test path.

        The path name is split on "_", the parts forming the operator name
        are dropped, and the remaining values are paired in order with
        self.wks_param_names.
        """
        op_name_parts = self.name.split("_")
        values = path.name.split("_")[len(op_name_parts) :]
        assert len(values) == len(
            self.wks_param_names
        ), f"len({values}) == len({self.wks_param_names})"
        return dict(zip(self.wks_param_names, values))

    def get_default_params(self):
        """Get the default parameter values from the test names.

        Returns a dictionary mapping each working parameter name to the
        sorted list of distinct values found in self.test_paths.
        """
        params = {param: set() for param in self.wks_param_names}
        for path in self.test_paths:
            path_params = self.path_params(path)
            for k, values in params.items():
                values.add(path_params[k])
        return {param: sorted(values) for param, values in params.items()}

    def select_tests(self):  # noqa: C901 (function too complex)
        """Generate the paths to the selected tests for this operator.

        Two passes are made. First, tests matching a permuted or preselected
        parameter combination are selected. Then, from the remaining tests in
        sorted order, tests providing a parameter value not yet covered are
        selected. Sparsity settings skip some candidates based on their
        index within the pass.
        """
        if not self.test_paths:
            # Exit early when nothing to select from
            return

        # the test paths that have not been selected yet
        unused_paths = set(self.test_paths)

        # a list of dictionaries of unused preselected parameter combinations
        unused_preselected = list(self.preselected)
        logger.debug(f"preselected: {unused_preselected}")

        # a list of dictionaries of unused permuted parameter combinations
        permutes = {k: self.params[k] for k in self.permutes}
        others = {k: self.params[k] for k in self.non_permutes}
        unused_permuted = list(expand_params(permutes, others))
        logger.debug(f"permuted: {unused_permuted}")

        # a dictionary of sets of unused parameter values
        if self.negative:
            # We only care about selecting a test for each errorif case
            unused_values = {k: set() for k in self.params}
            unused_values["case"] = set(self.params["case"])
        else:
            unused_values = {k: set(v) for k, v in self.params.items()}

        # select tests matching permuted, or preselected, parameter combinations
        for n, path in enumerate(self.test_paths):
            path_params = self.path_params(path)
            if path_params in unused_permuted or path_params in unused_preselected:
                unused_paths.remove(path)
                if path_params in unused_preselected:
                    unused_preselected.remove(path_params)
                if path_params in unused_permuted:
                    unused_permuted.remove(path_params)
                    if self.negative:
                        # remove any other errorif cases, so we only match one
                        for p in list(unused_permuted):
                            if p["case"] == path_params["case"]:
                                unused_permuted.remove(p)
                if self.full_sparsity:
                    # Test for sparsity
                    skip = False
                    for k in path_params:
                        if k in self.sparsity and n % self.sparsity[k] != 0:
                            logger.debug(f"Skipping due to {k} sparsity - {path.name}")
                            skip = True
                            break
                    if skip:
                        continue
                # remove the param values used by this path
                for k in path_params:
                    unused_values[k].discard(path_params[k])
                logger.debug(f"FOUND wanted: {path.name}")
                yield path

        # search for tests that match any unused parameter values
        for n, path in enumerate(sorted(unused_paths)):
            path_params = self.path_params(path)
            # select paths with unused param values
            # skipping some, if sparsity is set for the param
            for k in path_params:
                if path_params[k] in unused_values[k] and (
                    k not in self.sparsity or n % self.sparsity[k] == 0
                ):
                    # remove the param values used by this path
                    for p in path_params:
                        unused_values[p].discard(path_params[p])
                    sparsity = self.sparsity.get(k, 0)
                    logger.debug(f"FOUND unused [{k}/{n}/{sparsity}]: {path.name}")
                    yield path
                    break

        if not self.ignore_missing:
            # report any preselected combinations that were not found
            for params in unused_preselected:
                logger.warning(f"MISSING preselected: {params}")
            # report any permuted combinations that were not found
            for params in unused_permuted:
                logger.debug(f"MISSING permutation: {params}")
            # report any param values that were not found
            for k, values in unused_values.items():
                if values:
                    if k not in self.sparsity:
                        logger.warning(
                            f"MISSING {len(values)} values for {k}: {values}"
                        )
                    else:
                        logger.info(
                            f"Skipped {len(values)} values for {k} due to sparsity setting"
                        )
                        logger.debug(f"Values skipped: {values}")
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100384
385
class AbsOperator(Operator):
    """Select ABS tests by shape and type."""

    name = "abs"
390
391
class ArithmeticRightShiftOperator(Operator):
    """Select ARITHMETIC_RIGHT_SHIFT tests by shape, type and rounding."""

    name = "arithmetic_right_shift"
    param_names = [
        "shape",
        "type",
        "rounding",
    ]
397
398
class AddOperator(Operator):
    """Select ADD tests by shape and type."""

    name = "add"
403
404
class ArgmaxOperator(Operator):
    """Select ARGMAX tests by shape, type and axis."""

    name = "argmax"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
410
411
class AvgPool2dOperator(Operator):
    """Select AVG_POOL2D tests.

    Test names carry shape, type, accum_type, stride, kernel and pad.
    """

    name = "avg_pool2d"
    param_names = [
        "shape",
        "type",
        "accum_type",
        "stride",
        "kernel",
        "pad",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100417
418
class BitwiseAndOperator(Operator):
    """Select BITWISE_AND tests by shape and type."""

    name = "bitwise_and"
423
424
class BitwiseNotOperator(Operator):
    """Select BITWISE_NOT tests by shape and type."""

    name = "bitwise_not"
429
430
class BitwiseOrOperator(Operator):
    """Select BITWISE_OR tests by shape and type."""

    name = "bitwise_or"
435
436
class BitwiseXorOperator(Operator):
    """Select BITWISE_XOR tests by shape and type."""

    name = "bitwise_xor"
441
442
class CastOperator(Operator):
    """Select CAST tests by shape, input type and output type."""

    name = "cast"
    param_names = [
        "shape",
        "type",
        "output_type",
    ]
448
449
class CeilOperator(Operator):
    """Select CEIL tests by shape and type."""

    name = "ceil"
454
455
class ClampOperator(Operator):
    """Select CLAMP tests by shape and type."""

    name = "clamp"
460
461
class CLZOperator(Operator):
    """Select CLZ tests by shape and type."""

    name = "clz"
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100466
467
class ConcatOperator(Operator):
    """Select CONCAT tests by shape, type and axis."""

    name = "concat"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
473
474
class CondIfOperator(Operator):
    """Select COND_IF tests by variant, shape, type and condition."""

    name = "cond_if"
    # Test names start with the variant, before the usual shape/type values
    param_names = [
        "variant",
        "shape",
        "type",
        "cond",
    ]
480
481
class ConstOperator(Operator):
    """Select CONST tests by shape and type."""

    name = "const"
486
487
class Conv2dOperator(Operator):
    """Select CONV2D tests.

    Test names carry kernel, shape, type, accum_type, stride, pad and dilation.
    """

    name = "conv2d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100493
494
class Conv3dOperator(Operator):
    """Select CONV3D tests.

    Test names carry kernel, shape, type, accum_type, stride, pad and dilation.
    """

    name = "conv3d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100500
501
class DepthwiseConv2dOperator(Operator):
    """Select DEPTHWISE_CONV2D tests.

    Test names carry kernel, shape, type, accum_type, stride, pad and dilation.
    """

    name = "depthwise_conv2d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "dilation",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100507
508
class DimOperator(Operator):
    """Test selector for the DIM operator (shape, type and axis params)."""

    name = "dim"
    param_names = ["shape", "type", "axis"]


# Backwards-compatible alias for the original, misspelt class name; the
# alias does not re-register anything as no new subclass is created
DimOeprator = DimOperator
514
515
class EqualOperator(Operator):
    """Select EQUAL tests by shape and type."""

    name = "equal"
520
521
class ExpOperator(Operator):
    """Select EXP tests by shape and type."""

    name = "exp"
526
527
class ErfOperator(Operator):
    """Select ERF tests by shape and type."""

    name = "erf"
532
533
class FFT2DOperator(Operator):
    """Select FFT2D tests by shape, type and the inverse flag."""

    name = "fft2d"
    param_names = [
        "shape",
        "type",
        "inverse",
    ]
539
540
class FloorOperator(Operator):
    """Select FLOOR tests by shape and type."""

    name = "floor"
545
546
class FullyConnectedOperator(Operator):
    """Select FULLY_CONNECTED tests by shape, type and accumulator type."""

    name = "fully_connected"
    param_names = [
        "shape",
        "type",
        "accum_type",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100552
553
class GatherOperator(Operator):
    """Select GATHER tests by shape and type."""

    name = "gather"
558
559
class GreaterOperator(Operator):
    """Select GREATER tests by shape and type."""

    name = "greater"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Yield this operator's tests from the "greater" directory only."""
        # An exact directory name, rather than "greater*", stops the
        # greater_equal test directories from being picked up as well
        yield from Operator._get_test_paths(test_dir, cls.name, "*", negative)
569
570
class GreaterEqualOperator(Operator):
    """Select GREATER_EQUAL tests by shape and type."""

    name = "greater_equal"
575
576
class IdentityOperator(Operator):
    """Select IDENTITY tests by shape and type."""

    name = "identity"
581
582
class IntDivOperator(Operator):
    """Select INTDIV tests by shape and type."""

    name = "intdiv"
587
588
class LogOperator(Operator):
    """Test selector for the LOG operator."""

    name = "log"

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Generate test paths for this operator.

        Only the base directory named exactly "log" is searched: the default
        "log*" glob also matches the logical_* operator directories. Their
        test names do not split into LOG's shape/type params, so they would
        fail the parameter count check in path_params.
        """
        yield from Operator._get_test_paths(test_dir, f"{cls.name}", "*", negative)
593
594
class LogicalAndOperator(Operator):
    """Select LOGICAL_AND tests by shape and type."""

    name = "logical_and"
599
600
class LogicalLeftShiftOperator(Operator):
    """Select LOGICAL_LEFT_SHIFT tests by shape and type."""

    name = "logical_left_shift"
605
606
class LogicalNotOperator(Operator):
    """Select LOGICAL_NOT tests by shape and type."""

    name = "logical_not"
611
612
class LogicalOrOperator(Operator):
    """Select LOGICAL_OR tests by shape and type."""

    name = "logical_or"
617
618
class LogicalRightShiftOperator(Operator):
    """Select LOGICAL_RIGHT_SHIFT tests by shape and type."""

    name = "logical_right_shift"
623
624
class LogicalXorOperator(Operator):
    """Select LOGICAL_XOR tests by shape and type."""

    name = "logical_xor"
629
630
class MatmulOperator(Operator):
    """Select MATMUL tests by shape, type and accumulator type."""

    name = "matmul"
    param_names = [
        "shape",
        "type",
        "accum_type",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100636
637
class MaximumOperator(Operator):
    """Select MAXIMUM tests by shape and type."""

    name = "maximum"
642
643
class MaxPool2dOperator(Operator):
    """Select MAX_POOL2D tests by shape, type, stride, kernel and pad."""

    name = "max_pool2d"
    param_names = [
        "shape",
        "type",
        "stride",
        "kernel",
        "pad",
    ]
649
650
class MinimumOperator(Operator):
    """Select MINIMUM tests by shape and type."""

    name = "minimum"
655
656
class MulOperator(Operator):
    """Select MUL tests by shape, type, perm and shift."""

    name = "mul"
    param_names = [
        "shape",
        "type",
        "perm",
        "shift",
    ]
662
663
class NegateOperator(Operator):
    """Select NEGATE tests by shape and type."""

    name = "negate"
668
669
class PadOperator(Operator):
    """Select PAD tests by shape, type and padding."""

    name = "pad"
    param_names = [
        "shape",
        "type",
        "pad",
    ]
675
676
class PowOperator(Operator):
    """Select POW tests by shape and type."""

    name = "pow"
681
682
class ReciprocalOperator(Operator):
    """Select RECIPROCAL tests by shape and type."""

    name = "reciprocal"
687
688
class ReduceAllOperator(Operator):
    """Select REDUCE_ALL tests by shape, type and axis."""

    name = "reduce_all"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
694
695
class ReduceAnyOperator(Operator):
    """Select REDUCE_ANY tests by shape, type and axis."""

    name = "reduce_any"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
701
702
class ReduceMaxOperator(Operator):
    """Select REDUCE_MAX tests by shape, type and axis."""

    name = "reduce_max"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
708
709
class ReduceMinOperator(Operator):
    """Select REDUCE_MIN tests by shape, type and axis."""

    name = "reduce_min"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
715
716
class ReduceProductOperator(Operator):
    """Select REDUCE_PRODUCT tests by shape, type and axis."""

    name = "reduce_product"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
722
723
class ReduceSumOperator(Operator):
    """Select REDUCE_SUM tests by shape, type and axis."""

    name = "reduce_sum"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
729
730
class RescaleOperator(Operator):
    """Select RESCALE tests.

    Test names carry shape, type, output_type, scale, double_round and
    per_channel values, in that order.
    """

    name = "rescale"
    param_names = [
        "shape",
        "type",
        "output_type",
        "scale",
        "double_round",
        "per_channel",
    ]
743
744
class ReshapeOperator(Operator):
    """Select RESHAPE tests by shape, type, perm, rank and out."""

    name = "reshape"
    param_names = [
        "shape",
        "type",
        "perm",
        "rank",
        "out",
    ]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100750
751
class ResizeOperator(Operator):
    """Select RESIZE tests.

    Test names carry shape, type, mode, output_type, scale, offset and border.
    """

    name = "resize"
    param_names = ["shape", "type", "mode", "output_type", "scale", "offset", "border"]
765
766
class ReverseOperator(Operator):
    """Select REVERSE tests by shape, type and axis."""

    name = "reverse"
    param_names = [
        "shape",
        "type",
        "axis",
    ]
772
773
class RFFT2DOperator(Operator):
    """Select RFFT2D tests by shape and type."""

    name = "rfft2d"
778
779
class RsqrtOperator(Operator):
    """Select RSQRT tests by shape and type."""

    name = "rsqrt"
784
785
class ScatterOperator(Operator):
    """Select SCATTER tests by shape and type."""

    name = "scatter"
790
791
class SelectOperator(Operator):
    """Select SELECT tests by shape and type."""

    name = "select"
796
797
class SigmoidOperator(Operator):
    """Select SIGMOID tests by shape and type."""

    name = "sigmoid"
802
803
class SliceOperator(Operator):
    """Select SLICE tests by shape, type and perm."""

    name = "slice"
    param_names = [
        "shape",
        "type",
        "perm",
    ]
809
810
class SubOperator(Operator):
    """Select SUB tests by shape and type."""

    name = "sub"
815
816
class TableOperator(Operator):
    """Select TABLE tests by shape and type."""

    name = "table"
821
822
class TanhOperator(Operator):
    """Select TANH tests by shape and type."""

    name = "tanh"
827
828
class TileOperator(Operator):
    """Select TILE tests by shape, type and perm."""

    name = "tile"
    param_names = [
        "shape",
        "type",
        "perm",
    ]
834
835
class TransposeOperator(Operator):
    """Select TRANSPOSE tests by shape, type and perm."""

    name = "transpose"
    param_names = [
        "shape",
        "type",
        "perm",
    ]

    @classmethod
    def get_test_paths(cls, test_dir: Path, negative):
        """Yield this operator's tests from the "transpose" directory only."""
        # An exact directory name, rather than "transpose*", stops the
        # transpose_conv2d test directories from being picked up as well
        yield from Operator._get_test_paths(test_dir, cls.name, "*", negative)
846
847
class TransposeConv2dOperator(Operator):
    """Select TRANSPOSE_CONV2D tests.

    Test names carry kernel, shape, type, accum_type, stride, pad and
    out_shape values; out_shape is neutralised when splitting the name.
    """

    name = "transpose_conv2d"
    param_names = [
        "kernel",
        "shape",
        "type",
        "accum_type",
        "stride",
        "pad",
        "out_shape",
    ]

    def path_params(self, path):
        """Split the test name into params, blanking out the out_shape value."""
        # Every test has its own out_shape, so an empty placeholder keeps it
        # from influencing which tests get selected
        return {**super().path_params(path), "out_shape": ""}
868
869
class WhileLoopOperator(Operator):
    """Select WHILE_LOOP tests by shape, type and condition."""

    name = "while_loop"
    param_names = [
        "shape",
        "type",
        "cond",
    ]
875
876
def parse_args():
    """Parse the command line arguments.

    Returns:
        argparse.Namespace with test_dir, config, selector, full_path,
        verbosity, operators and test_type attributes
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-dir",
        default=Path.cwd(),
        type=Path,
        help=(
            "The directory where test subdirectories for all operators can be found"
            " (default: current working directory)"
        ),
    )
    parser.add_argument(
        "--config",
        default=Path(__file__).with_suffix(".json"),
        type=Path,
        help="A JSON file defining the parameters to use for each operator",
    )
    parser.add_argument(
        "--selector",
        default="default",
        type=str,
        help="The selector in the selection dictionary to use for each operator",
    )
    parser.add_argument(
        "--full-path", action="store_true", help="output the full path for each test"
    )
    parser.add_argument(
        "-v",
        dest="verbosity",
        action="count",
        default=0,
        help="Verbosity (can be used multiple times for more details)",
    )
    parser.add_argument(
        "operators",
        type=str,
        nargs="*",
        help=(
            "Select tests for the specified operator(s)"
            " (all operators are assumed if none are specified)"
            f" - choose from: {list(Operator.registry)}"
        ),
    )
    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative"],
        default="positive",
        type=str,
        help="type of tests selected, positive or negative",
    )
    return parser.parse_args()
930
931
def main():
    """Example test selection.

    Parses the command line, loads the JSON selection config, and prints the
    tests selected for each requested operator (every registered operator if
    none are named).

    Returns:
        0 on success, or 2 if the config file could not be read or parsed
    """
    args = parse_args()

    loglevels = (logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG)
    logger.setLevel(loglevels[min(args.verbosity, len(loglevels) - 1)])
    logger.info(f"{__file__}: args: {args}")

    try:
        with open(args.config, "r", encoding="utf-8") as fd:
            config = json.load(fd)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError
        logger.error(f"Config file error: {e}")
        return 2

    # Names not in the registry would otherwise be silently ignored
    for op_name in args.operators:
        if op_name not in Operator.registry:
            logger.error(f"Unknown operator: {op_name}")

    negative = args.test_type == "negative"
    for op_name, op_class in Operator.registry.items():
        if args.operators and op_name not in args.operators:
            continue
        op_params = config[op_name] if op_name in config else {}
        if "selection" in op_params and args.selector in op_params["selection"]:
            selection_config = op_params["selection"][args.selector]
        else:
            logger.warning(
                f"Could not find selection config {args.selector} for {op_name}"
            )
            selection_config = {}
        op = op_class(args.test_dir, selection_config, negative)
        for test_path in op.select_tests():
            print(test_path.resolve() if args.full_path else test_path.name)

    return 0
963
964
if __name__ == "__main__":
    # SystemExit rather than the site-module exit() helper, which is meant
    # for the interactive interpreter and is missing under "python -S"
    raise SystemExit(main())