blob: 69322cc0cd18e270bf195f2d5f674e962280cd1b [file] [log] [blame]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001# Copyright (c) 2020-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
Eric Kunzee5e26762020-10-13 16:11:07 -07003import argparse
Eric Kunzee5e26762020-10-13 16:11:07 -07004import re
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson2ec34942021-12-14 16:34:05 +00006from generator.tosa_test_gen import TosaTestGen
7from serializer.tosa_serializer import dtype_str_to_val
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00009
def str_to_list(in_s):
    """Convert a comma-separated string of integers to a list of ints.

    Used as an argparse ``type`` converter, e.g. ``"1,64"`` -> ``[1, 64]``.

    Args:
        in_s: String such as ``"1,4,4,8"``. Each comma-separated field is
            passed to ``int()``, so whitespace around a field is tolerated.

    Returns:
        List of ints, one per comma-separated field, in input order.

    Raises:
        ValueError: If any field is not a base-10 integer literal, including
            an empty field (e.g. from a trailing comma such as ``"1,"``).
    """
    return [int(field) for field in in_s.split(",")]
19
Kevin Cheng550ccc52021-03-03 11:21:43 -080020
def auto_int(x):
    """Parse the string *x* as an int, inferring the base from its prefix.

    Mirrors ``int(x, 0)``: decimal (``"12"``), hex (``"0xc"``), octal
    (``"0o14"``) and binary (``"0b1100"``) literals are all accepted.
    """
    parsed = int(x, base=0)
    return parsed
24
Kevin Cheng550ccc52021-03-03 11:21:43 -080025
def parseArgs(argv=None):
    """Declare the test-generator command-line options and parse them.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original zero-argument behaviour.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o", dest="output_dir", type=str, default="vtest", help="Test output directory"
    )

    parser.add_argument(
        "--seed",
        dest="random_seed",
        default=42,
        type=int,
        help="Random seed for test generation",
    )

    parser.add_argument(
        "--filter",
        dest="filter",
        default="",
        type=str,
        help="Filter operator test names by this expression",
    )

    parser.add_argument(
        "-v", "--verbose", dest="verbose", action="count", help="Verbose operation"
    )

    # Constraints on tests
    # Converters are passed directly (not wrapped in a lambda) so that
    # argparse error messages name them, e.g. "invalid str_to_list value".
    parser.add_argument(
        "--tensor-dim-range",
        dest="tensor_shape_range",
        default="1,64",
        type=str_to_list,
        help="Min,Max range of tensor shapes",
    )

    parser.add_argument(
        "--max-batch-size",
        dest="max_batch_size",
        default=1,
        type=int,
        help="Maximum batch size for NHWC tests",
    )

    parser.add_argument(
        "--max-conv-padding",
        dest="max_conv_padding",
        default=1,
        type=int,
        help="Maximum padding for Conv tests",
    )

    parser.add_argument(
        "--max-conv-dilation",
        dest="max_conv_dilation",
        default=2,
        type=int,
        help="Maximum dilation for Conv tests",
    )

    parser.add_argument(
        "--max-conv-stride",
        dest="max_conv_stride",
        default=2,
        type=int,
        help="Maximum stride for Conv tests",
    )

    parser.add_argument(
        "--max-pooling-padding",
        dest="max_pooling_padding",
        default=1,
        type=int,
        help="Maximum padding for pooling tests",
    )

    parser.add_argument(
        "--max-pooling-stride",
        dest="max_pooling_stride",
        default=2,
        type=int,
        help="Maximum stride for pooling tests",
    )

    parser.add_argument(
        "--max-pooling-kernel",
        dest="max_pooling_kernel",
        default=3,
        type=int,
        help="Maximum kernel for pooling tests",
    )

    parser.add_argument(
        "--num-rand-permutations",
        dest="num_rand_permutations",
        default=6,
        type=int,
        help="Number of random permutations for a given shape/rank for randomly-sampled parameter spaces",
    )

    # Targeting a specific shape/rank/dtype
    parser.add_argument(
        "--target-shape",
        dest="target_shapes",
        action="append",
        default=[],
        type=str_to_list,
        help="Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)",
    )

    parser.add_argument(
        "--target-rank",
        dest="target_ranks",
        action="append",
        default=None,
        type=auto_int,
        help="Create tests with a particular input tensor rank",
    )

    parser.add_argument(
        "--target-dtype",
        dest="target_dtypes",
        action="append",
        default=None,
        type=dtype_str_to_val,
        help="Create test with a particular DType (may be repeated)",
    )

    parser.add_argument(
        "--num-const-inputs-concat",
        dest="num_const_inputs_concat",
        default=0,
        choices=[0, 1, 2, 3],
        type=int,
        help="Allow constant input tensors for concat operator",
    )

    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative", "both"],
        default="positive",
        type=str,
        help="type of tests produced, positive, negative, or both",
    )

    parser.add_argument(
        "--allow-pooling-and-conv-oversizes",
        dest="oversize",
        action="store_true",
        help="allow oversize padding, stride and kernel tests",
    )

    args = parser.parse_args(argv)

    return args
182
Eric Kunzee5e26762020-10-13 16:11:07 -0700183
def main():
    """Generate and serialize the TOSA tests selected on the command line.

    For each requested test type ("positive", "negative", or both when
    ``--test-type both``), every op in ``TosaTestGen.TOSA_OP_LIST`` whose name
    matches ``--filter`` contributes the tests returned by ``genOpTestList``.
    Tests sharing a ``testStr`` within one test type are serialized once.
    Returns ``None``, so the process exit status is 0 on normal completion.
    """
    args = parseArgs()

    ttg = TosaTestGen(args)

    if args.test_type == "both":
        test_types = ["positive", "negative"]
    else:
        test_types = [args.test_type]

    # re.match anchors at the start of the op name, so the filter acts as a
    # prefix regex. Compile it once instead of on every op of every type.
    op_filter = re.compile(args.filter + ".*")

    results = []
    for test_type in test_types:
        test_list = []
        for op in ttg.TOSA_OP_LIST:
            if op_filter.match(op):
                test_list.extend(
                    ttg.genOpTestList(
                        op,
                        shapeFilter=args.target_shapes,
                        rankFilter=args.target_ranks,
                        dtypeFilter=args.target_dtypes,
                        testType=test_type,
                    )
                )

        print(f"{len(test_list)} matching {test_type} tests")

        # testStr values already serialized for this test type. A set keeps
        # the duplicate check O(1) rather than scanning a growing list.
        seen_test_strs = set()
        for opName, testStr, dtype, error, shapeList, testArgs in test_list:
            # Check for and skip duplicate tests
            if testStr in seen_test_strs:
                continue
            seen_test_strs.add(testStr)

            if args.verbose:
                print(testStr)
            results.append(
                ttg.serializeTest(opName, testStr, dtype, error, shapeList, testArgs)
            )

    print(f"Done creating {len(results)} tests")
226
Eric Kunzee5e26762020-10-13 16:11:07 -0700227
if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is intended only
    # for the interactive interpreter and is absent under `python -S`.
    import sys

    sys.exit(main())