blob: bc1ec8e6509beb16cacaaece60a1d0a963665472 [file] [log] [blame]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001# Copyright (c) 2020-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
Eric Kunzee5e26762020-10-13 16:11:07 -07003import argparse
Eric Kunzee5e26762020-10-13 16:11:07 -07004import re
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +01005import sys
Eric Kunzee5e26762020-10-13 16:11:07 -07006
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007from generator.tosa_test_gen import TosaTestGen
8from serializer.tosa_serializer import dtype_str_to_val
James Ward24dbc422022-10-19 12:20:31 +01009from serializer.tosa_serializer import DTypeNames
Eric Kunzee5e26762020-10-13 16:11:07 -070010
# Command-line flag for the floating point value range. Its value may start
# with a minus sign (e.g. "-2.0,2.0"), so parseArgs rewrites
# "FLAG VALUE" into "FLAG=VALUE" before handing argv to argparse.
OPTION_FP_VALUES_RANGE = "--fp-values-range"
12
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000013
# argparse ``type`` converter for comma-separated numeric options
# (e.g. "--tensor-dim-range 1,64" or "--fp-values-range=-2.0,2.0").
def str_to_list(in_s, is_float=False):
    """Split a comma-separated string into a list of numbers.

    Every comma-separated field is converted with ``float`` when *is_float*
    is true and with ``int`` otherwise, e.g. ``"1,64"`` -> ``[1, 64]`` and
    ``"-2.0,2.0"`` (is_float=True) -> ``[-2.0, 2.0]``.

    Raises:
        ValueError: if any field is not a valid number for the chosen type.
    """
    convert = float if is_float else int
    return [convert(field) for field in in_s.split(",")]
24
Kevin Cheng550ccc52021-03-03 11:21:43 -080025
def auto_int(x):
    """Return *x* parsed as an integer, inferring the base from its prefix.

    Passing base 0 to ``int`` makes it honour ``0x``/``0o``/``0b`` prefixes,
    so both ``"16"`` and ``"0x10"`` yield 16.
    """
    infer_base_from_prefix = 0
    return int(x, infer_base_from_prefix)
29
Kevin Cheng550ccc52021-03-03 11:21:43 -080030
def _join_fp_values_range_argv(argv):
    """Return a copy of *argv* with ``--fp-values-range VAL`` joined into one token.

    argparse treats a token that begins with "-" (and is not a plain negative
    number) as an option, so ``--fp-values-range -2.0,2.0`` would be rejected.
    When the value after the flag starts with "-", the pair is rewritten as
    ``--fp-values-range=-2.0,2.0``; every other token is copied unchanged.
    """
    new_argv = []
    idx = 0
    while idx < len(argv):
        arg = argv[idx]
        if arg == OPTION_FP_VALUES_RANGE and idx + 1 < len(argv):
            val = argv[idx + 1]
            if val.startswith("-"):
                arg = f"{arg}={val}"
                idx += 1  # the value has been folded into ``arg``
        new_argv.append(arg)
        idx += 1
    return new_argv


def parseArgs(argv=None):
    """Parse the test generator command line.

    Args:
        argv: Argument list excluding the program name; ``None`` means
            ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    if argv is None:
        argv = sys.argv[1:]

    # Argparse fix for hyphen (minus values) in argument values:
    # "--fp-values-range -2.0,2.0" -> "--fp-values-range=-2.0,2.0"
    argv = _join_fp_values_range_argv(argv)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o", dest="output_dir", type=str, default="vtest", help="Test output directory"
    )

    parser.add_argument(
        "--seed",
        dest="random_seed",
        default=42,
        type=int,
        help="Random seed for test generation",
    )

    parser.add_argument(
        "--filter",
        dest="filter",
        default="",
        type=str,
        help="Filter operator test names by this expression",
    )

    parser.add_argument(
        "-v", "--verbose", dest="verbose", action="count", help="Verbose operation"
    )

    # Constraints on tests
    # The converters are passed directly (not wrapped in a lambda) so that
    # argparse error messages name the converter rather than "<lambda>".
    parser.add_argument(
        "--tensor-dim-range",
        dest="tensor_shape_range",
        default="1,64",
        type=str_to_list,
        help="Min,Max range of tensor shapes",
    )

    parser.add_argument(
        OPTION_FP_VALUES_RANGE,
        dest="tensor_fp_value_range",
        default="0.0,1.0",
        type=lambda x: str_to_list(x, is_float=True),
        help="Min,Max range of floating point tensor values",
    )

    parser.add_argument(
        "--max-batch-size",
        dest="max_batch_size",
        default=1,
        type=int,
        help="Maximum batch size for NHWC tests",
    )

    parser.add_argument(
        "--max-conv-padding",
        dest="max_conv_padding",
        default=1,
        type=int,
        help="Maximum padding for Conv tests",
    )

    parser.add_argument(
        "--max-conv-dilation",
        dest="max_conv_dilation",
        default=2,
        type=int,
        help="Maximum dilation for Conv tests",
    )

    parser.add_argument(
        "--max-conv-stride",
        dest="max_conv_stride",
        default=2,
        type=int,
        help="Maximum stride for Conv tests",
    )

    parser.add_argument(
        "--max-pooling-padding",
        dest="max_pooling_padding",
        default=1,
        type=int,
        help="Maximum padding for pooling tests",
    )

    parser.add_argument(
        "--max-pooling-stride",
        dest="max_pooling_stride",
        default=2,
        type=int,
        help="Maximum stride for pooling tests",
    )

    parser.add_argument(
        "--max-pooling-kernel",
        dest="max_pooling_kernel",
        default=3,
        type=int,
        help="Maximum kernel for pooling tests",
    )

    parser.add_argument(
        "--num-rand-permutations",
        dest="num_rand_permutations",
        default=6,
        type=int,
        help="Number of random permutations for a given shape/rank for randomly-sampled parameter spaces",
    )

    parser.add_argument(
        "--max-resize-output-dim",
        dest="max_resize_output_dim",
        default=1000,
        type=int,
        help="Upper limit on width and height output dimensions for `resize` op. Default: 1000",
    )

    # Targeting a specific shape/rank/dtype
    parser.add_argument(
        "--target-shape",
        dest="target_shapes",
        action="append",
        default=[],
        type=str_to_list,
        help="Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)",
    )

    parser.add_argument(
        "--target-rank",
        dest="target_ranks",
        action="append",
        default=None,
        type=auto_int,
        help="Create tests with a particular input tensor rank",
    )

    # Repeatable; each value is converted by dtype_str_to_val.
    parser.add_argument(
        "--target-dtype",
        dest="target_dtypes",
        action="append",
        default=None,
        type=dtype_str_to_val,
        help=f"Create test with a particular DType: [{', '.join([d.lower() for d in DTypeNames[1:]])}] (may be repeated)",
    )

    parser.add_argument(
        "--num-const-inputs-concat",
        dest="num_const_inputs_concat",
        default=0,
        choices=[0, 1, 2, 3],
        type=int,
        help="Allow constant input tensors for concat operator",
    )

    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative", "both"],
        default="positive",
        type=str,
        help="type of tests produced, positive, negative, or both",
    )

    parser.add_argument(
        "--allow-pooling-and-conv-oversizes",
        dest="oversize",
        action="store_true",
        help="allow oversize padding, stride and kernel tests",
    )

    parser.add_argument(
        "--zero-point",
        dest="zeropoint",
        default=None,
        type=int,
        help="set a particular zero point for all valid positive tests",
    )

    parser.add_argument(
        "--dump-const-tensors",
        dest="dump_consts",
        action="store_true",
        help="output const tensors as numpy files for inspection",
    )

    return parser.parse_args(argv)
239
Eric Kunzee5e26762020-10-13 16:11:07 -0700240
def main(argv=None):
    """Generate and serialize the tests selected on the command line.

    For each requested test type ("positive", "negative", or both when
    ``--test-type both``), every operator in ``ttg.TOSA_OP_LIST`` whose name
    matches ``--filter`` (a regular expression matched at the start of the
    name) contributes the tests from ``ttg.genOpTestList``. A test whose name
    was already produced for the same test type is reported and skipped; the
    rest are passed to ``ttg.serializeTest``.

    Args:
        argv: Argument list excluding the program name; ``None`` means
            ``sys.argv[1:]``.
    """
    args = parseArgs(argv)

    ttg = TosaTestGen(args)

    if args.test_type == "both":
        test_types = ["positive", "negative"]
    else:
        test_types = [args.test_type]

    # Compile the operator filter once instead of once per operator; the
    # pattern text is the same as before (re.match anchors at the start).
    op_filter = re.compile(args.filter + ".*")

    results = []
    for test_type in test_types:
        testList = []
        for op in ttg.TOSA_OP_LIST:
            if op_filter.match(op):
                testList.extend(
                    ttg.genOpTestList(
                        op,
                        shapeFilter=args.target_shapes,
                        rankFilter=args.target_ranks,
                        dtypeFilter=args.target_dtypes,
                        testType=test_type,
                    )
                )

        print("{} matching {} tests".format(len(testList), test_type))

        # A set gives O(1) duplicate checks; the previous list lookup made
        # this loop quadratic in the number of tests. Test names are assumed
        # to be hashable strings.
        seen_test_names = set()
        for opName, testStr, dtype, error, shapeList, testArgs in testList:
            if testStr in seen_test_names:
                print(f"Skipping duplicate test: {testStr}")
                continue
            seen_test_names.add(testStr)

            results.append(
                ttg.serializeTest(opName, testStr, dtype, error, shapeList, testArgs)
            )

    print(f"Done creating {len(results)} tests")
282
Eric Kunzee5e26762020-10-13 16:11:07 -0700283
Kevin Cheng550ccc52021-03-03 11:21:43 -0800284if __name__ == "__main__":
Eric Kunzee5e26762020-10-13 16:11:07 -0700285 exit(main())