blob: 6ee873f50981d6bdea1641434465d145bb937aaa [file] [log] [blame]
# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import re
import sys

from generator.tosa_test_gen import TosaTestGen
from serializer.tosa_serializer import dtype_str_to_val
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00009
# Used for parsing a comma-separated list of integers in a string
# to an actual list of integers
def str_to_list(in_s):
    """Converts a comma-separated list of string integers to a python list of ints.

    Args:
        in_s: Text such as "1,4,4,8". It is split on "," and each field is
            passed to int().

    Returns:
        A list of ints, in the same order as the fields of in_s.

    Raises:
        ValueError: If any field is not a valid integer literal. This
            includes the empty field produced by "" or by a trailing comma.
    """
    return [int(field) for field in in_s.split(",")]
19
Kevin Cheng550ccc52021-03-03 11:21:43 -080020
def auto_int(x):
    """Converts hex/dec argument values to an int.

    Args:
        x: String holding an integer literal, e.g. "42" or "0x2A".

    Returns:
        The parsed int.

    Raises:
        ValueError: If x is not a valid integer literal.
    """
    # Base 0 makes int() pick the radix from the literal's own prefix
    # (0x, 0o, 0b, or none for decimal).
    return int(x, 0)
24
Kevin Cheng550ccc52021-03-03 11:21:43 -080025
def parseArgs(argv=None):
    """Builds the command-line parser and parses the arguments.

    Args:
        argv: Optional list of argument strings to parse. When None
            (the default), argparse reads the arguments from sys.argv[1:].

    Returns:
        The argparse.Namespace holding one attribute per `dest` declared
        below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o", dest="output_dir", type=str, default="vtest", help="Test output directory"
    )

    parser.add_argument(
        "--seed",
        dest="random_seed",
        default=42,
        type=int,
        help="Random seed for test generation",
    )

    parser.add_argument(
        "--filter",
        dest="filter",
        default="",
        type=str,
        help="Filter operator test names by this expression",
    )

    # No default is given, so args.verbose is None unless -v is passed.
    parser.add_argument(
        "-v", "--verbose", dest="verbose", action="count", help="Verbose operation"
    )

    # Constraints on tests
    # The string default is run through `type` by argparse, so the parsed
    # value is a list of ints even when the option is not supplied.
    parser.add_argument(
        "--tensor-dim-range",
        dest="tensor_shape_range",
        default="1,64",
        type=str_to_list,
        help="Min,Max range of tensor shapes",
    )

    parser.add_argument(
        "--max-batch-size",
        dest="max_batch_size",
        default=1,
        type=int,
        help="Maximum batch size for NHWC tests",
    )

    parser.add_argument(
        "--max-conv-padding",
        dest="max_conv_padding",
        default=1,
        type=int,
        help="Maximum padding for Conv tests",
    )

    parser.add_argument(
        "--max-conv-dilation",
        dest="max_conv_dilation",
        default=2,
        type=int,
        help="Maximum dilation for Conv tests",
    )

    parser.add_argument(
        "--max-conv-stride",
        dest="max_conv_stride",
        default=2,
        type=int,
        help="Maximum stride for Conv tests",
    )

    parser.add_argument(
        "--max-pooling-padding",
        dest="max_pooling_padding",
        default=1,
        type=int,
        help="Maximum padding for pooling tests",
    )

    parser.add_argument(
        "--max-pooling-stride",
        dest="max_pooling_stride",
        default=2,
        type=int,
        help="Maximum stride for pooling tests",
    )

    parser.add_argument(
        "--max-pooling-kernel",
        dest="max_pooling_kernel",
        default=3,
        type=int,
        help="Maximum kernel for pooling tests",
    )

    parser.add_argument(
        "--num-rand-permutations",
        dest="num_rand_permutations",
        default=6,
        type=int,
        help="Number of random permutations for a given shape/rank for randomly-sampled parameter spaces",
    )

    parser.add_argument(
        "--max-resize-output-dim",
        dest="max_resize_output_dim",
        default=1000,
        type=int,
        help="Upper limit on width and height output dimensions for `resize` op. Default: 1000",
    )

    # Targeting a specific shape/rank/dtype
    parser.add_argument(
        "--target-shape",
        dest="target_shapes",
        action="append",
        default=[],
        type=str_to_list,
        help="Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)",
    )

    parser.add_argument(
        "--target-rank",
        dest="target_ranks",
        action="append",
        default=None,
        type=auto_int,
        help="Create tests with a particular input tensor rank",
    )

    parser.add_argument(
        "--target-dtype",
        dest="target_dtypes",
        action="append",
        default=None,
        type=dtype_str_to_val,
        help="Create test with a particular DType (may be repeated)",
    )

    parser.add_argument(
        "--num-const-inputs-concat",
        dest="num_const_inputs_concat",
        default=0,
        choices=[0, 1, 2, 3],
        type=int,
        help="Allow constant input tensors for concat operator",
    )

    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative", "both"],
        default="positive",
        type=str,
        help="type of tests produced, positive, negative, or both",
    )

    parser.add_argument(
        "--allow-pooling-and-conv-oversizes",
        dest="oversize",
        action="store_true",
        help="allow oversize padding, stride and kernel tests",
    )

    args = parser.parse_args(argv)

    return args
190
Eric Kunzee5e26762020-10-13 16:11:07 -0700191
def main():
    """Generates and serializes the tests selected on the command line.

    Parses the arguments, builds a TosaTestGen from them, collects the test
    list of every operator whose name matches --filter for each requested
    test type, serializes each distinct test, and prints progress counts.
    """
    args = parseArgs()

    ttg = TosaTestGen(args)

    if args.test_type == "both":
        testType = ["positive", "negative"]
    else:
        testType = [args.test_type]

    # The pattern does not change between operators, so compile it once.
    # match() only anchors at the start of the operator name.
    opFilter = re.compile(args.filter + ".*")

    results = []
    for test_type in testType:
        testList = []
        for op in ttg.TOSA_OP_LIST:
            if opFilter.match(op):
                testList.extend(
                    ttg.genOpTestList(
                        op,
                        shapeFilter=args.target_shapes,
                        rankFilter=args.target_ranks,
                        dtypeFilter=args.target_dtypes,
                        testType=test_type,
                    )
                )

        print("{} matching {} tests".format(len(testList), test_type))

        # Names already serialized for this test type. A set keeps the
        # duplicate check constant-time instead of rescanning a list.
        # NOTE(review): assumes testStr is hashable (a test name string).
        seenTestStrings = set()
        for opName, testStr, dtype, error, shapeList, testArgs in testList:
            # Check for and skip duplicate tests
            if testStr in seenTestStrings:
                print(f"Skipping duplicate test: {testStr}")
                continue
            seenTestStrings.add(testStr)

            results.append(
                ttg.serializeTest(opName, testStr, dtype, error, shapeList, testArgs)
            )

    print(f"Done creating {len(results)} tests")
233
Eric Kunzee5e26762020-10-13 16:11:07 -0700234
if __name__ == "__main__":
    # Run the generator only when executed as a script. sys.exit() is used
    # because the bare exit() helper is injected by the site module for
    # interactive sessions and is not guaranteed to exist in every run mode.
    sys.exit(main())