blob: ab78b1aa6912d602fa09b44c52f99af56ee7f312 [file] [log] [blame]
# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import re
from generator.tosa_test_gen import TosaTestGen
from serializer.tosa_serializer import dtype_str_to_val
from serializer.tosa_serializer import DTypeNames
# Used for parsing a comma-separated list of integers in a string
# to an actual list of integers
def str_to_list(in_s):
    """Parse a comma-separated string of integers into a list of ints.

    Example: "1,4,4,8" -> [1, 4, 4, 8].

    Raises:
        ValueError: if any comma-separated field is not an integer literal
            (argparse reports this as an invalid argument value).
    """
    return [int(field) for field in in_s.split(",")]
def auto_int(x):
    """Parse an integer string, inferring the base from its prefix.

    Plain decimal ("12") is accepted, as are prefixed literals such as
    hexadecimal ("0x0c"), octal ("0o14") and binary ("0b1100").
    """
    return int(x, base=0)
def parseArgs(argv):
    """Build the test generator's command-line parser and parse ``argv``.

    Args:
        argv: List of argument strings, or None to let argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the parsed options. ``verbose`` is None
        (not 0) when ``-v`` is absent, because the ``count`` action below
        is given no explicit default.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o", dest="output_dir", type=str, default="vtest", help="Test output directory"
    )
    parser.add_argument(
        "--seed",
        dest="random_seed",
        default=42,
        type=int,
        help="Random seed for test generation",
    )
    parser.add_argument(
        "--filter",
        dest="filter",
        default="",
        type=str,
        help="Filter operator test names by this expression",
    )
    parser.add_argument(
        "-v", "--verbose", dest="verbose", action="count", help="Verbose operation"
    )

    # Constraints on tests
    parser.add_argument(
        "--tensor-dim-range",
        dest="tensor_shape_range",
        # argparse passes string defaults through `type`, so the parsed
        # default is the list [1, 64].
        default="1,64",
        type=lambda x: str_to_list(x),
        help="Min,Max range of tensor shapes",
    )
    parser.add_argument(
        "--max-batch-size",
        dest="max_batch_size",
        default=1,
        type=int,
        help="Maximum batch size for NHWC tests",
    )
    parser.add_argument(
        "--max-conv-padding",
        dest="max_conv_padding",
        default=1,
        type=int,
        help="Maximum padding for Conv tests",
    )
    parser.add_argument(
        "--max-conv-dilation",
        dest="max_conv_dilation",
        default=2,
        type=int,
        help="Maximum dilation for Conv tests",
    )
    parser.add_argument(
        "--max-conv-stride",
        dest="max_conv_stride",
        default=2,
        type=int,
        help="Maximum stride for Conv tests",
    )
    parser.add_argument(
        "--max-pooling-padding",
        dest="max_pooling_padding",
        default=1,
        type=int,
        help="Maximum padding for pooling tests",
    )
    parser.add_argument(
        "--max-pooling-stride",
        dest="max_pooling_stride",
        default=2,
        type=int,
        help="Maximum stride for pooling tests",
    )
    parser.add_argument(
        "--max-pooling-kernel",
        dest="max_pooling_kernel",
        default=3,
        type=int,
        help="Maximum kernel for pooling tests",
    )
    parser.add_argument(
        "--num-rand-permutations",
        dest="num_rand_permutations",
        default=6,
        type=int,
        help="Number of random permutations for a given shape/rank for randomly-sampled parameter spaces",
    )
    parser.add_argument(
        "--max-resize-output-dim",
        dest="max_resize_output_dim",
        default=1000,
        type=int,
        help="Upper limit on width and height output dimensions for `resize` op. Default: 1000",
    )

    # Targeting a specific shape/rank/dtype. Each of these options may be
    # repeated (action="append"), collecting one value per occurrence.
    parser.add_argument(
        "--target-shape",
        dest="target_shapes",
        action="append",
        default=[],
        type=lambda x: str_to_list(x),
        help="Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)",
    )
    parser.add_argument(
        "--target-rank",
        dest="target_ranks",
        action="append",
        default=None,
        type=lambda x: auto_int(x),
        help="Create tests with a particular input tensor rank",
    )
    # Data-type names given on the command line are converted to
    # serializer dtype values; the help text lists the names from
    # DTypeNames, skipping its first entry.
    parser.add_argument(
        "--target-dtype",
        dest="target_dtypes",
        action="append",
        default=None,
        type=lambda x: dtype_str_to_val(x),
        help=f"Create test with a particular DType: [{', '.join([d.lower() for d in DTypeNames[1:]])}] (may be repeated)",
    )
    parser.add_argument(
        "--num-const-inputs-concat",
        dest="num_const_inputs_concat",
        default=0,
        choices=[0, 1, 2, 3],
        type=int,
        help="Allow constant input tensors for concat operator",
    )
    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative", "both"],
        default="positive",
        type=str,
        help="type of tests produced, positive, negative, or both",
    )
    parser.add_argument(
        "--allow-pooling-and-conv-oversizes",
        dest="oversize",
        action="store_true",
        help="allow oversize padding, stride and kernel tests",
    )
    parser.add_argument(
        "--zero-point",
        dest="zeropoint",
        default=None,
        type=int,
        help="set a particular zero point for all valid positive tests",
    )
    parser.add_argument(
        "--dump-const-tensors",
        dest="dump_consts",
        action="store_true",
        help="output const tensors as numpy files for inspection",
    )

    args = parser.parse_args(argv)

    return args
def main(argv=None):
    """Generate and serialize TOSA operator tests.

    Parses ``argv``, then for each requested test type collects the tests of
    every operator in ``ttg.TOSA_OP_LIST`` whose name matches ``--filter``.
    Each unique test name is serialized once per test type; repeats are
    reported and skipped.

    Args:
        argv: Optional list of argument strings; None means ``sys.argv[1:]``
            (argparse's default).

    Returns:
        None, so ``sys.exit(main())`` exits with status 0.
    """
    args = parseArgs(argv)

    ttg = TosaTestGen(args)

    if args.test_type == "both":
        testType = ["positive", "negative"]
    else:
        testType = [args.test_type]

    # re.match anchors at the start of the op name, so the filter acts as a
    # prefix pattern. Compile it once rather than once per operator per
    # test type.
    opFilter = re.compile(args.filter + ".*")

    results = []
    for test_type in testType:
        testList = []
        for op in ttg.TOSA_OP_LIST:
            if opFilter.match(op):
                testList.extend(
                    ttg.genOpTestList(
                        op,
                        shapeFilter=args.target_shapes,
                        rankFilter=args.target_ranks,
                        dtypeFilter=args.target_dtypes,
                        testType=test_type,
                    )
                )

        print("{} matching {} tests".format(len(testList), test_type))

        # Names already serialized for this test type. A set makes the
        # duplicate check O(1) per test instead of a linear list scan.
        seenTestStrs = set()
        for opName, testStr, dtype, error, shapeList, testArgs in testList:
            if testStr in seenTestStrs:
                print(f"Skipping duplicate test: {testStr}")
                continue
            seenTestStrs.add(testStr)
            results.append(
                ttg.serializeTest(opName, testStr, dtype, error, shapeList, testArgs)
            )

    print(f"Done creating {len(results)} tests")
if __name__ == "__main__":
    # Use sys.exit rather than the exit() builtin: exit() is injected by the
    # `site` module and is not guaranteed to exist (e.g. under `python -S`
    # or in frozen executables).
    import sys

    sys.exit(main())