blob: 2284e35c625ecf9f1c60ae186e1363230dfd97b8 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
3# Copyright (c) 2020, ARM Limited.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import argparse
19import sys
20import re
21import os
22import subprocess
23import shlex
24import json
25import glob
26import math
27import queue
28import threading
29import traceback
30import importlib
31
32
33from enum import IntEnum, Enum, unique
34from datetime import datetime
35
# Include the ../scripts and ../scripts/xunit directories in the module search path (sys.path)
37parent_dir = os.path.dirname(os.path.realpath(__file__))
38sys.path.append(os.path.join(parent_dir, '..', 'scripts'))
39sys.path.append(os.path.join(parent_dir, '..', 'scripts', 'xunit'))
40import xunit
41import tosa
42from tosa_test_gen import TosaTestGen
43from tosa_test_runner import TosaTestRunner
44
45no_color_printing = False
46#from run_tf_unit_test import LogColors, print_color, run_sh_command
47
def parseArgs():
    """Parse the command line (sys.argv) for the regression runner.

    Returns:
        argparse.Namespace with the attributes test, random_seed,
        ref_model_path, ref_debug, ref_intermediates, verbose, jobs,
        sut_module, sut_module_args and xunit_file.

    Raises:
        SystemExit: raised by argparse on invalid arguments, including a
            missing -t/--test option.
    """
    parser = argparse.ArgumentParser()
    # At least one test is mandatory: without it args.test would be None and
    # the caller could not iterate over it.
    parser.add_argument('-t', '--test', dest='test', type=str, nargs='+',
                        required=True,
                        help='Test(s) to run')
    parser.add_argument('--seed', dest='random_seed', default=42, type=int,
                        help='Random seed for test generation')
    parser.add_argument('--ref-model-path', dest='ref_model_path',
                        default='build/reference_model/tosa_reference_model', type=str,
                        help='Path to reference model executable')
    parser.add_argument('--ref-debug', dest='ref_debug', default='', type=str,
                        help='Reference debug flag (low, med, high)')
    parser.add_argument('--ref-intermediates', dest='ref_intermediates', default=0, type=int,
                        help='Reference model dumps intermediate tensors')
    # default=0 keeps args.verbose an int even when -v is never given
    # (action='count' would otherwise leave it as None).
    parser.add_argument('-v', '--verbose', dest='verbose', action='count', default=0,
                        help='Verbose operation')
    parser.add_argument('-j', '--jobs', dest='jobs', type=int, default=1,
                        help='Number of parallel jobs')
    parser.add_argument('--sut-module', '-s', dest='sut_module', type=str, nargs='+',
                        default=['tosa_ref_run'],
                        help='System under test module to load (derives from TosaTestRunner). May be repeated')
    parser.add_argument('--sut-module-args', dest='sut_module_args', type=str, nargs='+',
                        default=[],
                        help='System under test module arguments. Use sutmodulename:argvalue to pass an argument. May be repeated.')
    parser.add_argument('--xunit-file', dest='xunit_file', type=str, default='result.xml',
                        help='XUnit output file')

    args = parser.parse_args()

    # A job count of zero or less means "autodetect the CPU count".
    # os.cpu_count() may return None when the count cannot be determined,
    # so fall back to a single job in that case.
    if args.jobs <= 0:
        args.jobs = os.cpu_count() or 1

    return args
80
def workerThread(task_queue, runnerList, args, result_queue):
    """Drain task_queue, running every test through each runner module.

    Args:
        task_queue: queue.Queue of tests to run. A None entry makes the
            worker stop early.
        runnerList: list of (runner_module, runner_args) pairs. Each module
            must expose a TosaRefRunner class that is constructed as
            TosaRefRunner(args, runner_args, test) and has a runModel()
            method whose return value is used as the result code.
        args: parsed command-line arguments; only args.verbose is read here,
            the object is otherwise handed to the runner unchanged.
        result_queue: queue.Queue that receives one
            (test, rc, msg, elapsed) tuple per test, where elapsed is a
            datetime.timedelta. With several runners, rc is the result of
            the last one that ran.

    Returns:
        True once the queue is empty or a None entry was taken from it.
    """
    while True:
        try:
            test = task_queue.get(block=False)
        except queue.Empty:
            break

        if test is None:
            # Balance the get() above so a join() on the queue cannot hang.
            task_queue.task_done()
            break

        msg = ''
        # Default covers an empty runnerList, where the loop never assigns rc.
        rc = TosaTestRunner.Result.INTERNAL_ERROR
        start_time = datetime.now()
        try:
            for runnerModule, runnerArgs in runnerList:
                if args.verbose:
                    print('Running runner {} with test {}'.format(runnerModule.__name__, test))
                runner = runnerModule.TosaRefRunner(args, runnerArgs, test)
                try:
                    rc = runner.runModel()
                except Exception as e:
                    # Report the failure instead of dropping it silently.
                    print('Runner {} raised an exception on test {}: {}'.format(
                        runnerModule.__name__, test, e))
                    rc = TosaTestRunner.Result.INTERNAL_ERROR
        except Exception as e:
            print('Internal regression error: {}'.format(e))
            # Positional arguments: the 'etype' keyword was removed in
            # Python 3.10, and using it here would kill the worker before
            # task_done() is reached.
            print(''.join(traceback.format_exception(type(e), e, e.__traceback__)))
            rc = TosaTestRunner.Result.INTERNAL_ERROR

        end_time = datetime.now()

        result_queue.put((test, rc, msg, end_time - start_time))
        task_queue.task_done()

    return True
114
def loadRefModules(args):
    """Import every system-under-test module named in args.sut_module.

    Entries of args.sut_module_args of the form 'modulename:value' are
    matched to their module by the 'modulename:' prefix, which is stripped.

    Returns:
        A list of (module, [argument strings]) tuples, one per entry of
        args.sut_module and in the same order.
    """
    loaded = []
    for module_name in args.sut_module:
        if args.verbose:
            print('Loading module {}'.format(module_name))

        module = importlib.import_module(module_name)

        # Keep only the arguments addressed to this module, prefix removed.
        prefix = '{}:'.format(module_name)
        module_args = [
            arg[len(prefix):]
            for arg in args.sut_module_args
            if arg.startswith(prefix)
        ]
        loaded.append((module, module_args))

    return loaded
133
def main():
    """Run the requested tests on worker threads and report the results.

    Parses the command line, loads the runner modules, starts args.jobs
    worker threads over a shared task queue, then writes an XUnit report to
    args.xunit_file and prints a per-result total line.

    Returns:
        0 in all cases; individual test outcomes are reported through the
        printed output and the XUnit file, not through the return value.
    """
    args = parseArgs()

    runnerList = loadRefModules(args)

    threads = []
    taskQueue = queue.Queue()
    resultQueue = queue.Queue()

    # args.test is None when -t/--test was not supplied; treat as no tests.
    for test_name in args.test or []:
        taskQueue.put(test_name)

    print('Running {} tests '.format(taskQueue.qsize()))

    for _ in range(args.jobs):
        # daemon=True replaces Thread.setDaemon(), deprecated since 3.10.
        worker = threading.Thread(target=workerThread,
                                  args=(taskQueue, runnerList, args, resultQueue),
                                  daemon=True)
        worker.start()
        threads.append(worker)

    # Blocks until every queued test has been marked done by a worker.
    taskQueue.join()

    resultList = []
    # One counter per result code.
    # NOTE(review): assumes Result members are integers usable as list
    # indexes in range(len(Result)) - confirm in tosa_test_runner.
    results = [0] * len(TosaTestRunner.Result)

    while True:
        try:
            test, rc, msg, time_delta = resultQueue.get(block=False)
        except queue.Empty:
            break
        resultQueue.task_done()

        resultList.append((test, rc, msg, time_delta))
        results[rc] = results[rc] + 1

    xunit_result = xunit.xunit_results('Regressions')
    xunit_suite = xunit_result.create_suite('Unit tests')

    # Sort by test name
    for test, rc, msg, time_delta in sorted(resultList, key=lambda tup: tup[0]):
        test_name = test
        xt = xunit.xunit_test(test_name, 'reference')

        xt.time = str(float(time_delta.seconds) + (float(time_delta.microseconds) * 1e-6))

        if rc == TosaTestRunner.Result.EXPECTED_PASS or rc == TosaTestRunner.Result.EXPECTED_FAILURE:
            # Expected outcomes are only echoed in verbose mode.
            if args.verbose:
                print('{} {}'.format(rc.name, test_name))
        else:
            xt.failed(msg)
            print('{} {}'.format(rc.name, test_name))

        xunit_suite.tests.append(xt)

    xunit_result.write_results(args.xunit_file)

    print('Totals: ', end='')
    for result in TosaTestRunner.Result:
        print('{} {}, '.format(results[result], result.name.lower()), end ='')
    print()

    return 0
196
if __name__ == '__main__':
    # sys.exit() is always available; the bare exit() helper is injected by
    # the site module for interactive use and is absent under `python -S`.
    sys.exit(main())