blob: 19eb2f49436dddd90583d5fe317ffd7997843e9e [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
3# Copyright (c) 2020, ARM Limited.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import argparse
19import sys
20import re
21import os
22import subprocess
23import shlex
24import json
25import glob
26import math
27import queue
28import threading
29import traceback
30
31
32from enum import IntEnum, Enum, unique
33from datetime import datetime
34
# Add ../scripts and ../scripts/xunit (relative to this file's real location,
# with symlinks resolved) to the module search path. Presumably the project
# modules imported next (xunit, tosa_serializer, tosa, ...) live there.
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(parent_dir, '..', 'scripts'))
sys.path.append(os.path.join(parent_dir, '..', 'scripts', 'xunit'))
39import xunit
40from tosa_serializer import *
41from tosa_test_gen import TosaTestGen
42import tosa
43
def str_to_list(in_s):
    '''Convert a comma-separated string of integers into a list of ints.

    Suitable as an argparse ``type`` callable, e.g. '1,4,4,8' -> [1, 4, 4, 8].

    Each item is parsed with int(), so whitespace around an item is
    tolerated (' 4' -> 4), while an empty item (for example from a trailing
    comma) or a non-integer item raises ValueError.
    '''
    return [int(item) for item in in_s.split(',')]
53
def auto_int(x):
    '''Parse the string *x* as an integer, inferring the base from its prefix.

    With base 0, int() accepts Python integer-literal syntax: '0x1f' -> 31,
    '0o17' -> 15, '0b101' -> 5, and an unprefixed '42' -> 42.
    '''
    return int(x, base=0)
57
def parseArgs():
    '''Define the test builder's command-line options and parse sys.argv.

    Returns:
        argparse.Namespace with one attribute per option below (named by each
        option's ``dest``). main() reads filter/target_shapes/target_ranks/
        target_dtypes, and the whole namespace is handed to TosaTestGen, so
        the ``dest`` names must stay stable.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', dest='output_dir', type=str, default='vtest',
                        help='Test output directory')

    parser.add_argument('--seed', dest='random_seed', default=42, type=int,
                        help='Random seed for test generation')

    parser.add_argument('--filter', dest='filter', default='', type=str,
                        help='Filter operator test names by this expression')

    # An explicit default of 0 keeps args.verbose an int; a 'count' action
    # otherwise leaves it as None when -v is never given.
    parser.add_argument('-v', '--verbose', dest='verbose', action='count', default=0,
                        help='Verbose operation')

    # Constraints on tests
    # argparse passes string defaults through `type`, so the '1,64' default
    # is converted to [1, 64].
    parser.add_argument('--tensor-dim-range', dest='tensor_shape_range', default='1,64',
                        type=str_to_list,
                        help='Min,Max range of tensor shapes')

    parser.add_argument('--max-batch-size', dest='max_batch_size', default=1, type=int,
                        help='Maximum batch size for NHWC tests')

    parser.add_argument('--max-conv-padding', dest='max_conv_padding', default=1, type=int,
                        help='Maximum padding for Conv tests')

    parser.add_argument('--max-conv-dilation', dest='max_conv_dilation', default=2, type=int,
                        help='Maximum dilation for Conv tests')

    parser.add_argument('--max-conv-stride', dest='max_conv_stride', default=2, type=int,
                        help='Maximum stride for Conv tests')

    parser.add_argument('--max-pooling-padding', dest='max_pooling_padding', default=1, type=int,
                        help='Maximum padding for pooling tests')

    parser.add_argument('--max-pooling-stride', dest='max_pooling_stride', default=2, type=int,
                        help='Maximum stride for pooling tests')

    parser.add_argument('--max-pooling-kernel', dest='max_pooling_kernel', default=2, type=int,
                        help='Maximum kernel size for pooling tests')

    parser.add_argument('--num-rand-permutations', dest='num_rand_permutations', default=6, type=int,
                        help='Number of random permutations for a given shape/rank for randomly-sampled parameter spaces')

    # Targeting a specific shape/rank/dtype
    parser.add_argument('--target-shape', dest='target_shapes', action='append', default=[],
                        type=str_to_list,
                        help='Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)')

    parser.add_argument('--target-rank', dest='target_ranks', action='append', default=None,
                        type=auto_int,
                        help='Create tests with a particular input tensor rank')

    # NOTE(review): dtype_str_to_val is not defined in this module's visible
    # code; presumably it arrives via `from tosa_serializer import *` — confirm.
    # The lambda defers that name lookup until a --target-dtype value is parsed.
    parser.add_argument('--target-dtype', dest='target_dtypes', action='append', default=None,
                        type=lambda x: dtype_str_to_val(x),
                        help='Create test with a particular DType (may be repeated)')

    return parser.parse_args()
115
def main():
    '''Generate and serialize every operator test selected on the command line.

    Operators from TosaTestGen.TOSA_OP_LIST are kept when the --filter
    expression matches at the start of their name. Their test lists are
    narrowed by the --target-shape/--target-rank/--target-dtype options, and
    each resulting test is printed and then serialized.
    '''
    args = parseArgs()
    ttg = TosaTestGen(args)

    selected = []
    for op_name in ttg.TOSA_OP_LIST:
        # re.match anchors at the beginning of op_name, so --filter acts as a
        # regular-expression prefix match on operator names.
        if not re.match(args.filter + '.*', op_name):
            continue
        selected.extend(
            ttg.genOpTestList(op_name,
                              shapeFilter=args.target_shapes,
                              rankFilter=args.target_ranks,
                              dtypeFilter=args.target_dtypes))

    print('{} matching tests'.format(len(selected)))
    for test_op, test_name, test_dtype, test_shapes, test_args in selected:
        print(test_name)
        ttg.serializeTest(test_op, test_name, test_dtype, test_shapes, test_args)
    print('Done creating {} tests'.format(len(selected)))
133
134
if __name__ == '__main__':
    # Use sys.exit rather than the bare `exit` builtin: the latter is
    # injected by the site module for interactive use and is missing when
    # Python runs with -S. main() returns None, which exits with status 0.
    sys.exit(main())