blob: 68e44da940d15faf74a5149ce170a551a062316c [file] [log] [blame]
# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import re
import sys
from generator.tosa_test_gen import TosaTestGen
from serializer.tosa_serializer import dtype_str_to_val
from serializer.tosa_serializer import DTypeNames
# Command line option taking a "min,max" pair of floats. Its value may start
# with a minus sign, which is why argument parsing special-cases this option.
OPTION_FP_VALUES_RANGE: str = "--fp-values-range"
# Used for parsing a comma-separated list of integers in a string
# to an actual list of integers
def str_to_list(in_s, is_float=False):
"""Converts a comma-separated list of string integers to a python list of ints"""
lst = in_s.split(",")
out_list = []
for i in lst:
val = float(i) if is_float else int(i)
out_list.append(val)
return out_list
def auto_int(x):
    """Parse an integer literal, inferring its base from the prefix.

    Plain decimal (``"12"``) is accepted, as are prefixed literals such as
    hexadecimal (``"0x1f"``), octal (``"0o17"``) and binary (``"0b101"``).
    """
    infer_base_from_prefix = 0  # base 0 tells int() to honour 0x/0o/0b
    return int(x, infer_base_from_prefix)
def _join_negative_fp_values_range(argv):
    """Rewrite ``--fp-values-range VAL`` as ``--fp-values-range=VAL`` when VAL starts with '-'.

    argparse treats an argument that begins with a hyphen (and is not a plain
    negative number) as an option flag, so ``--fp-values-range -2.0,2.0`` would
    be rejected with "expected one argument". Joining the pair into
    ``--fp-values-range=-2.0,2.0`` makes argparse read it as the option value.

    Args:
        argv: List of command line argument strings.

    Returns:
        ``argv`` itself when the option does not appear in it, otherwise a new
        list in which each hyphen-leading value has been merged into its option.
    """
    option = OPTION_FP_VALUES_RANGE
    if option not in argv:
        return argv
    joined = []
    idx = 0
    while idx < len(argv):
        arg = argv[idx]
        if arg == option and idx + 1 < len(argv) and argv[idx + 1].startswith("-"):
            # Consume the value together with the option it belongs to.
            arg = f"{arg}={argv[idx + 1]}"
            idx += 1
        joined.append(arg)
        idx += 1
    return joined


def parseArgs(argv):
    """Parse the command line arguments.

    Args:
        argv: Argument strings, excluding the program name. When None,
            ``sys.argv[1:]`` is used.

    Returns:
        The ``argparse.Namespace`` holding the parsed options. On invalid
        arguments argparse prints an error and exits.
    """
    if argv is None:
        argv = sys.argv[1:]
    # Let "--fp-values-range -2.0,2.0" through argparse (see helper).
    argv = _join_negative_fp_values_range(argv)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o", dest="output_dir", type=str, default="vtest", help="Test output directory"
    )
    parser.add_argument(
        "--seed",
        dest="random_seed",
        default=42,
        type=int,
        help="Random seed for test generation",
    )
    parser.add_argument(
        "--filter",
        dest="filter",
        default="",
        type=str,
        help="Filter operator test names by this expression",
    )
    parser.add_argument(
        "-v", "--verbose", dest="verbose", action="count", help="Verbose operation"
    )

    # Constraints on tests
    parser.add_argument(
        "--tensor-dim-range",
        dest="tensor_shape_range",
        default="1,64",
        type=str_to_list,
        help="Min,Max range of tensor shapes",
    )
    parser.add_argument(
        OPTION_FP_VALUES_RANGE,
        dest="tensor_fp_value_range",
        default="0.0,1.0",
        type=lambda x: str_to_list(x, is_float=True),
        help="Min,Max range of floating point tensor values",
    )
    parser.add_argument(
        "--max-batch-size",
        dest="max_batch_size",
        default=1,
        type=positive_integer_type,
        help="Maximum batch size for NHWC tests",
    )
    parser.add_argument(
        "--max-conv-padding",
        dest="max_conv_padding",
        default=1,
        type=int,
        help="Maximum padding for Conv tests",
    )
    parser.add_argument(
        "--max-conv-dilation",
        dest="max_conv_dilation",
        default=2,
        type=int,
        help="Maximum dilation for Conv tests",
    )
    parser.add_argument(
        "--max-conv-stride",
        dest="max_conv_stride",
        default=2,
        type=int,
        help="Maximum stride for Conv tests",
    )
    parser.add_argument(
        "--max-pooling-padding",
        dest="max_pooling_padding",
        default=1,
        type=int,
        help="Maximum padding for pooling tests",
    )
    parser.add_argument(
        "--max-pooling-stride",
        dest="max_pooling_stride",
        default=2,
        type=int,
        help="Maximum stride for pooling tests",
    )
    parser.add_argument(
        "--max-pooling-kernel",
        dest="max_pooling_kernel",
        default=3,
        type=int,
        help="Maximum kernel for pooling tests",
    )
    parser.add_argument(
        "--num-rand-permutations",
        dest="num_rand_permutations",
        default=6,
        type=int,
        help="Number of random permutations for a given shape/rank for randomly-sampled parameter spaces",
    )
    parser.add_argument(
        "--max-resize-output-dim",
        dest="max_resize_output_dim",
        default=1000,
        type=int,
        help="Upper limit on width and height output dimensions for `resize` op. Default: 1000",
    )

    # Targeting a specific shape/rank/dtype
    parser.add_argument(
        "--target-shape",
        dest="target_shapes",
        action="append",
        default=[],
        type=str_to_list,
        help="Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)",
    )
    parser.add_argument(
        "--target-rank",
        dest="target_ranks",
        action="append",
        default=None,
        type=auto_int,
        help="Create tests with a particular input tensor rank",
    )
    # Restrict generation to specific data types (names converted by the serializer)
    parser.add_argument(
        "--target-dtype",
        dest="target_dtypes",
        action="append",
        default=None,
        type=dtype_str_to_val,
        help=f"Create test with a particular DType: [{', '.join([d.lower() for d in DTypeNames[1:]])}] (may be repeated)",
    )
    parser.add_argument(
        "--num-const-inputs-concat",
        dest="num_const_inputs_concat",
        default=0,
        choices=[0, 1, 2, 3],
        type=int,
        help="Allow constant input tensors for concat operator",
    )
    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative", "both"],
        default="positive",
        type=str,
        help="type of tests produced, positive, negative, or both",
    )
    parser.add_argument(
        "--allow-pooling-and-conv-oversizes",
        dest="oversize",
        action="store_true",
        help="allow oversize padding, stride and kernel tests",
    )
    parser.add_argument(
        "--zero-point",
        dest="zeropoint",
        default=None,
        type=int,
        help="set a particular zero point for all valid positive tests",
    )
    parser.add_argument(
        "--dump-const-tensors",
        dest="dump_consts",
        action="store_true",
        help="output const tensors as numpy files for inspection",
    )
    return parser.parse_args(argv)
def positive_integer_type(argv_str):
    """Argparse ``type=`` converter that accepts only integers greater than zero.

    Args:
        argv_str: The raw command line string.

    Returns:
        The parsed integer.

    Raises:
        ValueError: If ``argv_str`` is not an integer literal.
        argparse.ArgumentTypeError: If the integer is zero or negative.
    """
    parsed = int(argv_str)
    if parsed > 0:
        return parsed
    raise argparse.ArgumentTypeError(f"{argv_str} is not a valid positive integer")
def main(argv=None):
    """Generate and serialize TOSA operator tests.

    Parses ``argv``, then for each requested test type collects the test list
    of every operator in ``TosaTestGen.TOSA_OP_LIST`` whose name matches
    ``--filter``. Each test is serialized; a test whose name (``testStr``) was
    already seen for the same test type is skipped.

    Args:
        argv: Argument strings, excluding the program name; None means
            ``sys.argv[1:]``.
    """
    args = parseArgs(argv)
    ttg = TosaTestGen(args)

    if args.test_type == "both":
        testTypes = ["positive", "negative"]
    else:
        testTypes = [args.test_type]

    # Compile the operator-name filter once rather than per operator;
    # re.match anchors at the start of the operator name.
    opFilter = re.compile(args.filter + ".*")

    results = []
    for test_type in testTypes:
        testList = []
        for op in ttg.TOSA_OP_LIST:
            if opFilter.match(op):
                testList.extend(
                    ttg.genOpTestList(
                        op,
                        shapeFilter=args.target_shapes,
                        rankFilter=args.target_ranks,
                        dtypeFilter=args.target_dtypes,
                        testType=test_type,
                    )
                )
        print(f"{len(testList)} matching {test_type} tests")

        # A set gives O(1) duplicate detection (a list made this quadratic).
        seenTestStrs = set()
        for opName, testStr, dtype, error, shapeList, testArgs in testList:
            # Check for and skip duplicate tests
            if testStr in seenTestStrs:
                print(f"Skipping duplicate test: {testStr}")
                continue
            seenTestStrs.add(testStr)
            results.append(
                ttg.serializeTest(opName, testStr, dtype, error, shapeList, testArgs)
            )

    print(f"Done creating {len(results)} tests")
if __name__ == "__main__":
    # Use sys.exit rather than the site-injected exit(), which is absent under
    # ``python -S`` and in some frozen/embedded interpreters.
    sys.exit(main())