Eric Kunze | e5e2676 | 2020-10-13 16:11:07 -0700 | [diff] [blame^] | 1 | #!/usr/bin/env python3 |
| 2 | |
| 3 | # Copyright (c) 2020, ARM Limited. |
| 4 | # |
| 5 | # Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | # you may not use this file except in compliance with the License. |
| 7 | # You may obtain a copy of the License at |
| 8 | # |
| 9 | # http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | # |
| 11 | # Unless required by applicable law or agreed to in writing, software |
| 12 | # distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | # See the License for the specific language governing permissions and |
| 15 | # limitations under the License. |
| 16 | |
| 17 | |
| 18 | import argparse |
| 19 | import sys |
| 20 | import re |
| 21 | import os |
| 22 | import subprocess |
| 23 | import shlex |
| 24 | import json |
| 25 | import glob |
| 26 | import math |
| 27 | import queue |
| 28 | import threading |
| 29 | import traceback |
| 30 | |
| 31 | |
| 32 | from enum import IntEnum, Enum, unique |
| 33 | from datetime import datetime |
| 34 | |
# Add the ../scripts and ../scripts/xunit directories to the module search path (sys.path)
| 36 | parent_dir = os.path.dirname(os.path.realpath(__file__)) |
| 37 | sys.path.append(os.path.join(parent_dir, '..', 'scripts')) |
| 38 | sys.path.append(os.path.join(parent_dir, '..', 'scripts', 'xunit')) |
| 39 | import xunit |
| 40 | from tosa_serializer import * |
| 41 | from tosa_test_gen import TosaTestGen |
| 42 | import tosa |
| 43 | |
| 44 | # Used for parsing a comma-separated list of integers in a string |
| 45 | # to an actual list of integers |
def str_to_list(in_s):
    '''Parse a comma-separated string of integers (e.g. "1,64") into a list of ints.

    Each comma-delimited field is passed to int(), so a field that int()
    rejects raises ValueError.
    '''
    return [int(field) for field in in_s.split(',')]
| 53 | |
def auto_int(x):
    '''Return the integer spelled by the string x, inferring the base from its prefix.

    Base 0 makes int() follow Python literal rules, so "0x10", "0o17",
    "0b101" and plain decimal strings are all accepted.
    '''
    value = int(x, base=0)
    return value
| 57 | |
def parseArgs():
    '''Define the test generator's command-line options and parse sys.argv.

    Returns:
        The argparse.Namespace produced by parse_args(). Attribute names are
        the dest= values given below (output_dir, random_seed, filter,
        verbose, tensor_shape_range, the max_* limits, num_rand_permutations,
        target_shapes, target_ranks, target_dtypes).

    On invalid arguments argparse prints a usage message and exits the process.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', dest='output_dir', type=str, default='vtest',
                        help='Test output directory')

    parser.add_argument('--seed', dest='random_seed', default=42, type=int,
                        help='Random seed for test generation')

    parser.add_argument('--filter', dest='filter', default='', type=str,
                        help='Filter operator test names by this expression')

    # With action='count' and no default, args.verbose is None when -v is absent.
    parser.add_argument('-v', '--verbose', dest='verbose', action='count',
                        help='Verbose operation')

    # Constraints on tests
    # The string default is run through the type converter by argparse, so
    # args.tensor_shape_range is a list of ints even when the flag is omitted.
    parser.add_argument('--tensor-dim-range', dest='tensor_shape_range', default='1,64',
                        type=str_to_list,
                        help='Min,Max range of tensor shapes')

    parser.add_argument('--max-batch-size', dest='max_batch_size', default=1, type=int,
                        help='Maximum batch size for NHWC tests')

    parser.add_argument('--max-conv-padding', dest='max_conv_padding', default=1, type=int,
                        help='Maximum padding for Conv tests')

    parser.add_argument('--max-conv-dilation', dest='max_conv_dilation', default=2, type=int,
                        help='Maximum dilation for Conv tests')

    parser.add_argument('--max-conv-stride', dest='max_conv_stride', default=2, type=int,
                        help='Maximum stride for Conv tests')

    parser.add_argument('--max-pooling-padding', dest='max_pooling_padding', default=1, type=int,
                        help='Maximum padding for pooling tests')

    parser.add_argument('--max-pooling-stride', dest='max_pooling_stride', default=2, type=int,
                        help='Maximum stride for pooling tests')

    # Help text previously duplicated the padding description.
    parser.add_argument('--max-pooling-kernel', dest='max_pooling_kernel', default=2, type=int,
                        help='Maximum kernel size for pooling tests')

    parser.add_argument('--num-rand-permutations', dest='num_rand_permutations', default=6, type=int,
                        help='Number of random permutations for a given shape/rank for randomly-sampled parameter spaces')

    # Targeting a specific shape/rank/dtype
    # The converters are passed directly (not wrapped in lambdas) so argparse
    # can name them in its "invalid ... value" error messages.
    parser.add_argument('--target-shape', dest='target_shapes', action='append', default=[], type=str_to_list,
                        help='Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)')

    parser.add_argument('--target-rank', dest='target_ranks', action='append', default=None, type=auto_int,
                        help='Create tests with a particular input tensor rank')

    # NOTE(review): dtype_str_to_val is not defined in this file; presumably it
    # comes from the `from tosa_serializer import *` star import - confirm.
    parser.add_argument('--target-dtype', dest='target_dtypes', action='append', default=None, type=dtype_str_to_val,
                        help='Create test with a particular DType (may be repeated)')

    return parser.parse_args()
| 115 | |
def main():
    '''Generate and serialize the operator tests selected on the command line.

    Parses the arguments, collects the test descriptions for every operator
    name in TosaTestGen.TOSA_OP_LIST that matches the --filter expression
    (anchored at the start of the name), then prints and serializes each one.
    Returns None.
    '''
    args = parseArgs()
    ttg = TosaTestGen(args)

    # Gather test tuples for each operator whose name matches the filter.
    selected = []
    for op_name in ttg.TOSA_OP_LIST:
        if not re.match(args.filter + '.*', op_name):
            continue
        op_tests = ttg.genOpTestList(
            op_name,
            shapeFilter=args.target_shapes,
            rankFilter=args.target_ranks,
            dtypeFilter=args.target_dtypes,
        )
        selected.extend(op_tests)

    total = len(selected)
    print('{} matching tests'.format(total))

    # Each entry unpacks to (opName, testStr, dtype, shapeList, testArgs).
    for entry in selected:
        opName, testStr, dtype, shapeList, testArgs = entry
        print(testStr)
        ttg.serializeTest(opName, testStr, dtype, shapeList, testArgs)

    print('Done creating {} tests'.format(total))
| 133 | |
| 134 | |
if __name__ == '__main__':
    # Use sys.exit rather than the bare exit() builtin: exit() is installed by
    # the site module and is unavailable under `python -S` or in frozen
    # interpreters. main() returns None, which yields exit status 0.
    sys.exit(main())