blob: 71723ae19de52ffa61c75aee36764e557bea3731 [file] [log] [blame]
Jeremy Johnson015c3552022-02-23 12:15:03 +00001#!/usr/bin/env python3
Luke Hutton261b7b62023-01-10 14:50:31 +00002# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson015c3552022-02-23 12:15:03 +00003# SPDX-License-Identifier: Apache-2.0
4import argparse
5import glob
6import json
7import math
8import os
9import queue
10import re
11import sys
12import threading
13import traceback
14from datetime import datetime
15from enum import IntEnum
16from enum import unique
17
18import numpy as np
19from checker.tosa_result_checker import LogColors
20from checker.tosa_result_checker import print_color
21from checker.tosa_result_checker import set_print_in_color
22from runner.run_command import run_sh_command
23from xunit.xunit import xunit_results
24from xunit.xunit import xunit_test
25
26
def parse_args(argv=None):
    """Parse command-line options for the test runner.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]`` as before — the
            parameter exists so the parser can be driven programmatically
            (e.g. from tests) without touching ``sys.argv``.

    Returns:
        argparse.Namespace with all options resolved: the framework list
        defaults to both ``tf`` and ``tflite`` when none was given, and a
        non-positive job count is replaced by the machine's CPU count.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--test",
        dest="test",
        default=[],
        type=str,
        nargs="+",
        help="Test(s) to run",
    )
    parser.add_argument(
        "-r",
        "--recursive",
        dest="recursive_tests",
        action="store_true",
        help="Recursively search for tests",
    )
    parser.add_argument(
        "--tf-base-dir",
        dest="tf_base_dir",
        type=str,
        required=True,
        help="Tensorflow/MLIR base directory",
    )
    parser.add_argument(
        "--tools-base-dir",
        dest="tools_base_dir",
        type=str,
        required=True,
        help="Reference model base directory",
    )
    parser.add_argument(
        "-v", "--verbose", dest="verbose", action="count", help="Verbose run"
    )
    parser.add_argument(
        "-dref",
        "--debug-ref-model",
        dest="debug_ref_model",
        action="store_true",
        help="Enable TOSA Reference model debugging",
    )
    parser.add_argument(
        "--tolerance",
        dest="tolerance",
        default=1e-3,
        type=float,
        help="Comparison tolerance b value",
    )
    parser.add_argument(
        "--no-compiler",
        dest="no_compiler",
        action="store_true",
        help="Do not run TF MLIR/tfopt/TOSA compiler. Just run TOSA Reference model",
    )
    parser.add_argument(
        "--no-ref-model",
        dest="no_ref",
        action="store_true",
        help="Do not run TOSA reference model, just run TF MLIR/tfopt/TOSA compiler.",
    )
    parser.add_argument(
        "--valgrind",
        dest="valgrind",
        action="store_true",
        help="Enable valgrind on TOSA Reference Model",
    )
    parser.add_argument(
        "-j", "--jobs", dest="jobs", type=int, default=1, help="Number of parallel jobs"
    )
    parser.add_argument(
        "--no-color",
        "--no-colour",
        dest="no_color",
        action="store_true",
        help="Disable color output",
    )
    parser.add_argument(
        "-f",
        "--framework",
        dest="framework",
        default=[],
        action="append",
        help="Frameworks to test (tf, tflite)",
    )
    parser.add_argument(
        "--override-exclusions",
        dest="override_exclusions",
        default=False,
        action="store_true",
        help="Ignore the framework exclusions listed in the test JSON",
    )
    parser.add_argument(
        "--xunit-file",
        dest="xunit_file",
        type=str,
        default="result.xml",
        help="XUnit result output file",
    )
    parser.add_argument(
        "--xunit-classname-prefix",
        dest="xunit_classname_prefix",
        default="TFUnitTests",
        help="Prefix for xunit classname",
    )
    parser.add_argument(
        "--hex-bool-hack",
        dest="hex_bool_hack",
        default=1,
        type=int,
        help=(
            "Hack around bug in MLIR hex parsing for boolean types"
            " by disabling hex encoding"
        ),
    )
    parser.add_argument(
        "--regression-mode",
        dest="regression_mode",
        default=False,
        action="store_true",
        help="Options to make the script more friendly for jenkins regressions",
    )
    parser.add_argument(
        "--quantize-tolerance",
        dest="quantize_tolerance",
        default=0,
        type=int,
        help=(
            "Tolerance when comparing TOSA reference model result"
            " to TensorFlow Lite reference"
        ),
    )
    parser.add_argument(
        "--test-dir",
        dest="test_dir",
        default="",
        help="Path to prepend to paths in test.json",
    )

    parser.add_argument(
        "-o", "--output", dest="output_file", help="Redirect script output to a file"
    )

    args = parser.parse_args(argv)

    # No easy way to both do array append and override a default value
    if not args.framework:
        args.framework = ["tf", "tflite"]

    # Autodetect CPU count
    if args.jobs <= 0:
        args.jobs = os.cpu_count()

    return args
181
182
@unique
class TestResult(IntEnum):
    """Outcome codes for a single test run.

    The integer values are used to index TestResultErrorStr and the
    per-result counters in main(), so the two must stay in sync.
    """

    PASS = 0
    COMPILER_ERROR = 1
    REF_MODEL_ERROR = 2
    REF_MODEL_UNPREDICTABLE = 3
    REF_MODEL_RUNTIME_ERROR = 4
    MISMATCH = 5
    NOT_LOWERED = 6
    INVALID_MLIR = 7
    INTERNAL_ERROR = 8
    SKIPPED = 9
195
196
# Human-readable message for each TestResult, indexed by the enum's integer
# value. PASS and SKIPPED map to empty strings so they emit no error text in
# the summary/xunit report. Keep order aligned with TestResult.
TestResultErrorStr = [
    "",
    "Compiler error",
    "Reference model error",
    "Reference model unpredictable",
    "Reference model runtime error",
    "Mismatch",
    "Not lowered",
    "Invalid MLIR",
    "Internal error",
    "",
]
209
210
def parse_compiler_output(compiler_stdout, compiler_stderr):
    """Classify the TOSA compiler's output.

    Scans stdout for the "has not been lowered yet, skipped" marker and
    returns TestResult.NOT_LOWERED when any line matches; otherwise
    TestResult.PASS. stderr is currently unexamined.
    """
    not_lowered = re.compile(".* has not been lowered yet, skipped.*")
    if any(not_lowered.match(out_line) for out_line in compiler_stdout.splitlines()):
        return TestResult.NOT_LOWERED
    return TestResult.PASS
220
221
def parse_reference_model_output(ref_model_stdout, ref_model_stderr):
    """Classify the TOSA reference model's output.

    Scans stderr line by line for known failure markers and returns the
    first matching TestResult (UNPREDICTABLE, graph ERROR, or unknown graph
    status). Returns TestResult.PASS when nothing matched. stdout is
    currently unexamined.
    """
    checks = (
        (re.compile(r".*UNPREDICTABLE.*"), TestResult.REF_MODEL_UNPREDICTABLE),
        (re.compile(".* Graph result: ERROR.*"), TestResult.REF_MODEL_ERROR),
        (re.compile(".* Unknown graph status code.*"), TestResult.REF_MODEL_RUNTIME_ERROR),
    )

    for err_line in ref_model_stderr.splitlines():
        for pattern, status in checks:
            if pattern.match(err_line):
                return status

    return TestResult.PASS
237
238
# write a self-contained test descriptor in json format
def write_reference_runner_json(
    filename,
    tosa_filename,
    ifm_name,
    ifm_file,
    ofm_name,
    ofm_file,
    expected_failure=False,
):
    """Write a json test file so that it is fairly easy to pick up the test
    and generate commands for third party tool"""
    descriptor = {
        "tosa_file": tosa_filename,
        "ifm_name": ifm_name,
        "ifm_file": ifm_file,
        "ofm_name": ofm_name,
        "ofm_file": ofm_file,
        "expected_failure": expected_failure,
    }

    with open(filename, "w") as json_fd:
        json.dump(descriptor, json_fd, indent=" ")
262
263
def run_test(args, test, framework):
    """Run one unit test end-to-end for a given framework.

    Pipeline: translate the framework model to MLIR, run tf-opt and the
    TOSA compiler, execute the TOSA reference model, then compare its
    output numpy against the framework's reference numpy.

    Args:
        args: parsed command-line namespace from parse_args().
        test: path to the test directory containing test.json.
        framework: "tf" or "tflite".

    Returns:
        (TestResult, tolerance, message) tuple. tolerance is 0.0 except on
        PASS (args.tolerance) or MISMATCH (the loosened tolerance that
        would have matched).

    Raises:
        Exception: when the test name cannot be derived from the path or
        test.json cannot be loaded/parsed.
    """

    # parse test_name from test directory path
    test_path = test.split("/")
    test_name = None
    for t in test_path[::-1]:
        if len(t) != 0:
            test_name = t
            break
    if not test_name:
        raise Exception("Could not parse test_name from {}".format(test))

    print_color(LogColors.GREEN, "## Running {} test {}".format(framework, test_name))

    msg = ""

    try:
        with open(os.path.join(test, "test.json"), "r") as f:
            test_desc = json.load(f)
    except Exception:
        raise Exception(
            "Could not load or parse test from {}".format(
                os.path.join(test, "test.json")
            )
        )

    # Honour per-test framework exclusions unless overridden on the
    # command line; a missing "framework_exclusions" key means none.
    try:
        if not args.override_exclusions:
            for excl in test_desc["framework_exclusions"]:
                if excl == framework:
                    print_color(LogColors.GREEN, "Results SKIPPED")
                    return (TestResult.SKIPPED, 0.0, "")
    except KeyError:
        pass

    tf_tools_dir = os.path.abspath(
        "{}/bazel-bin/tensorflow/compiler/mlir".format(args.tf_base_dir)
    )

    pre_opt_filename = os.path.join(test, "test_{}.preopt.mlir".format(framework))
    post_opt_filename = os.path.join(test, "test_{}.postopt.mlir".format(framework))
    if args.test_dir:
        test_path_prepend = args.test_dir
    else:
        test_path_prepend = test

    # 1. Framework to MLIR translator command
    # Tests shipped as .mlir skip translation entirely.
    if framework == "tf":
        if test_desc["tf_model_filename"].endswith(".mlir"):
            pre_opt_filename = test_desc["tf_model_filename"]
            translate_mlir_cmd = []
        else:
            translate_mlir_cmd = [
                os.path.join(tf_tools_dir, "tf-mlir-translate"),
                "--graphdef-to-mlir",
                "--tf-enable-shape-inference-on-import",
                "--tf-output-arrays={}".format(test_desc["tf_result_name"]),
                os.path.join(test_path_prepend, test_desc["tf_model_filename"]),
                "-o",
                pre_opt_filename,
            ]
    elif framework == "tflite":
        if test_desc["tflite_model_filename"].endswith(".mlir"):
            pre_opt_filename = test_desc["tflite_model_filename"]
            translate_mlir_cmd = []
        else:
            translate_mlir_cmd = [
                os.path.join(tf_tools_dir, "lite", "flatbuffer_translate"),
                "--tflite-flatbuffer-to-mlir",
                os.path.join(test_path_prepend, test_desc["tflite_model_filename"]),
                "--output-arrays={}".format(test_desc["tflite_result_name"]),
                "-o",
                pre_opt_filename,
            ]
    else:
        raise Exception("Unknown framwork: {}".format(framework))

    # Any additional inputs to the translator?
    input_tensor_prefix = "TosaInput_"
    flatbuffer_dir = "flatbuffer-{}".format(framework)
    mlir_opts = []

    # Temporary hack: MLIR's new hex encoding of large tensors does not work for
    # boolean types
    # for TF hash 8e8041d594a888eb67eafa5cc62627d7e9ca8082
    if test.endswith("_bool") and args.hex_bool_hack:
        mlir_opts.append("--mlir-print-elementsattrs-with-hex-if-larger=-1")

    try:
        # specify input tensors if test is generated from .pb
        if framework == "tf":
            # Convert the shape to a mlir-friendly string
            shapes = []
            for curr_shape in test_desc["ifm_shape"]:
                shape_str = ""
                for dim in curr_shape:
                    shape_str = shape_str + str(dim) + ","
                shapes.append(shape_str)

            translate_mlir_cmd.extend(
                ["--tf-input-arrays", ",".join(test_desc["ifm_name"])]
            )
            translate_mlir_cmd.extend(["--tf-input-shapes", ":".join(shapes)])

        # Write the hard-coded placeholder input (reshaped as necesary) to
        # the file that compiler specified.
        reference_runner_ifm_name = []
        for i in range(len(test_desc["ifm_file"])):

            ifm_tensor_name = "{}{}".format(input_tensor_prefix, i)

            assert test_desc["ifm_file"][i].endswith(".npy")
            ifm_np = np.load(os.path.join(test, test_desc["ifm_file"][i]))

            # We sometimes encounter input shape/expected input shape mismatches
            # due to a missing batch dimension on the input (e.g. a single 3D image).
            #
            # Make sure input numpy and input shape from descriptor match,
            # expand_dims on the outer dimensions until the rank matches,
            # then do the shape comparison.
            while len(list(ifm_np.shape)) < len(test_desc["ifm_shape"][i]):
                ifm_np = np.expand_dims(ifm_np, axis=0)

            # After legalization, complex tensors are expected to be represented
            # as a single floating point tensor of shape [?, ..., ?, 2].
            expected_shape = test_desc["ifm_shape"][i]
            if test.endswith("c64"):
                expected_shape.append(2)

            assert list(ifm_np.shape) == expected_shape

            reference_runner_ifm_name.append(ifm_tensor_name)

    except KeyError:
        # No additional inputs. Ignore.
        # NOTE(review): a KeyError raised before reference_runner_ifm_name is
        # assigned (e.g. missing "ifm_shape" for tf) would leave it undefined
        # for the descriptor write further down — confirm descriptors always
        # carry these keys.
        pass

    tf_opt_cmd = [
        os.path.join(tf_tools_dir, "tf-opt"),
        "--tf-executor-to-functional-conversion",
        "--verify-each",
        pre_opt_filename,
        "-o",
        post_opt_filename,
    ]

    translate_mlir_cmd.extend(mlir_opts)
    tf_opt_cmd.extend(mlir_opts)

    # 2. Build the TOSA compiler command (tf-opt with the TOSA pipeline).
    compiler_cmd = [os.path.join(tf_tools_dir, "tf-opt")]

    if framework == "tf":
        compiler_cmd.append("--tf-to-tosa-pipeline")
    elif framework == "tflite":
        compiler_cmd.append("--tfl-to-tosa-pipeline")
        compiler_cmd.append("--tosa-strip-quant-types")

    tosa_mlir_filename = os.path.join(test, "output_{}.tosa.mlir".format(framework))

    flatbuffer_dir_fullpath = os.path.join(test, flatbuffer_dir)

    os.makedirs(flatbuffer_dir_fullpath, exist_ok=True)

    compiler_cmd.extend(
        [
            "--verify-each",
            post_opt_filename,
            "-o",
            tosa_mlir_filename,
            "--tosa-serialize",
            "--tosa-flatbuffer-filename={}".format(
                os.path.join(flatbuffer_dir_fullpath, "{}.tosa".format(test_name))
            ),
        ]
    )

    if not args.no_compiler:
        try:
            if translate_mlir_cmd:
                run_sh_command(translate_mlir_cmd, args.verbose, True)
            if tf_opt_cmd:
                run_sh_command(tf_opt_cmd, args.verbose, True)
        except Exception as e:
            print_color(
                LogColors.RED, "Results INVALID_MLIR {}: {}".format(test_name, e)
            )
            return (TestResult.INVALID_MLIR, 0.0, e)

        try:

            compiler_stdout, compiler_stderr = run_sh_command(
                compiler_cmd, args.verbose, True
            )
            compiler_rc = parse_compiler_output(compiler_stdout, compiler_stderr)
            if compiler_rc == TestResult.NOT_LOWERED:
                print_color(
                    LogColors.RED,
                    "Results NOT_LOWERED {}, framework {}".format(test_name, framework),
                )
                return (TestResult.NOT_LOWERED, 0.0, "")

            pass

        except Exception as e:
            # "same scale constraint" failures are quantization legality
            # errors in the input MLIR rather than compiler crashes.
            if "same scale constraint" in str(e):
                print_color(
                    LogColors.RED, "Results INVALID_MLIR {}: {}".format(test_name, e)
                )
                return (TestResult.INVALID_MLIR, 0.0, e)
            else:
                print_color(
                    LogColors.RED, "Results COMPILER_ERROR {}: {}".format(test_name, e)
                )
                return (TestResult.COMPILER_ERROR, 0.0, e)

    # 3. Load the framework's golden result for later comparison.
    if framework == "tf":
        try:
            tf_result = np.load(os.path.join(test, test_desc["tf_result_npy_filename"]))
        except KeyError:
            assert 0, "fail to load tf result numpy"
    elif framework == "tflite":
        try:
            tf_result = np.load(
                os.path.join(test, test_desc["tflite_result_npy_filename"])
            )
        except KeyError:
            assert 0, "fail to load tflite result numpy"

    # TOSA has no notion of complex datatypes, it represents complex values using two
    # fp32 output tensors representing real and imaginary values. When legalizing
    # complex operations from frameworks, these two output tensors are combined into
    # a single tensor of shape [?, ..., ?, 2] whereby each inner pair of values
    # represents the real and imaginary parts of a complex value. This is completed
    # by inserting reshape and concatenate TOSA operations during the legalization to
    # maintain a one-to-one correspondance with framework outputs, thus simplifying
    # legalization. Here tf_result should also match this format before being
    # compared to the ref model output.
    if tf_result.dtype == np.complex64:
        ifm_shape = tf_result.shape + (2,)
        tf_result = tf_result.view(np.float32)
        tf_result = tf_result.reshape(ifm_shape)

    # Generate test descriptor per flatbuffer generation
    # Input .npy will be shared across different frameworks
    # Output .npy will be generated in its corresponding flatbuffer
    reference_runner_ifm_file = [
        os.path.join("..", ifm_file) for ifm_file in test_desc["ifm_file"]
    ]

    # Check if there's any operator in output graph.
    empty_graph = True
    with open(tosa_mlir_filename, "r") as f:
        for line in f:
            if re.search('"tosa.*"', line):
                empty_graph = False

                break

    # Fast-forward input tensor to output tensor if TOSA graph is empty.
    if empty_graph:
        reference_runner_ofm_name = reference_runner_ifm_name
    else:
        reference_runner_ofm_name = ["TosaOutput_0"]

    write_reference_runner_json(
        filename=os.path.join(test, flatbuffer_dir, "desc.json"),
        tosa_filename="{}.tosa".format(test_name),
        ifm_name=reference_runner_ifm_name,
        ifm_file=reference_runner_ifm_file,
        ofm_name=reference_runner_ofm_name,
        ofm_file=["ref_model_output_0.npy"],
    )

    # 4. Reference model invocation (optionally under valgrind).
    ref_model_cmd = [
        os.path.join(
            args.tools_base_dir, "build", "reference_model", "tosa_reference_model"
        ),
        "--test_desc={}".format(os.path.join(test, flatbuffer_dir, "desc.json")),
    ]

    if args.debug_ref_model:
        ref_model_cmd.extend(["-D ALL", "-l high"])

    if args.valgrind:
        ref_model_cmd = [
            "valgrind",
            "--show-leak-kinds=all",
            "--log-fd=1",
            "-q",
        ] + ref_model_cmd

    # Clean out any ref_model result first
    # NOTE(review): os.remove does not expand globs, so this literal
    # "ref_model_*.npy" path never exists and the call always falls into
    # FileNotFoundError — stale results are NOT actually removed. Confirm
    # whether a glob + per-file remove was intended.
    try:
        os.remove(os.path.join(test, flatbuffer_dir, "ref_model_*.npy"))
    except FileNotFoundError:
        pass

    if args.no_ref:
        return (TestResult.PASS, 0.0, msg)

    try:
        ref_model_stdout, ref_model_stderr = run_sh_command(
            ref_model_cmd, args.verbose, True
        )
        ref_model_rc = parse_reference_model_output(ref_model_stdout, ref_model_stderr)
        if ref_model_rc != TestResult.PASS:
            return (ref_model_rc, 0.0, "")
    except Exception as e:
        # The run may fail with a classifiable message in the exception
        # text; fall back to a generic runtime error otherwise.
        ref_model_rc = parse_reference_model_output("", str(e))
        if ref_model_rc != TestResult.PASS:
            print_color(
                LogColors.RED,
                "Results {} {}: {}".format(
                    TestResultErrorStr[ref_model_rc], test_name, e
                ),
            )
            return (ref_model_rc, 0.0, "")
        print_color(
            LogColors.RED,
            "Results REF_MODEL_RUNTIME_ERROR {}: {}".format(test_name, e),
        )
        return (TestResult.REF_MODEL_RUNTIME_ERROR, 0.0, e)

    # 5. Normalize dtypes before comparison: widen fp16 to fp32 and all
    # narrower/wider ints to int32 to match the reference model output.
    if tf_result.dtype == np.float16:
        tf_result = tf_result.astype(np.float32)
    elif (
        tf_result.dtype == np.uint8
        or tf_result.dtype == np.int8
        or tf_result.dtype == np.int16
        or tf_result.dtype == np.int64
    ):
        tf_result = tf_result.astype(np.int32)

    # For now, search for the output from ref_model
    ref_model_result_files = glob.glob(
        os.path.join(test, flatbuffer_dir, "ref_model_*.npy")
    )
    ref_model_result = np.load(ref_model_result_files[0])

    assert (
        tf_result.dtype == ref_model_result.dtype
    ), "Numpy type mismatch {} != {} when comparing result".format(
        tf_result.dtype, ref_model_result.dtype
    )

    # Size comparison
    # Size = 1 tensors can be equivalently represented as having rank 0 or rank
    # >= 0, allow that special case
    tf_result = np.squeeze(tf_result)
    ref_model_result = np.squeeze(ref_model_result)

    if np.shape(tf_result) != np.shape(ref_model_result):
        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(tf_result), np.shape(ref_model_result)
        )
        print(msg)
        return (TestResult.MISMATCH, 0.0, msg)

    # for quantized test, allow +-(args.quantize_tolerance) error
    if ref_model_result.dtype == np.int32:
        assert tf_result.dtype == np.int32

        if np.all(np.absolute(ref_model_result - tf_result) <= args.quantize_tolerance):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
        else:
            print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))

            tolerance = args.quantize_tolerance + 1
            # NOTE(review): the loop condition re-tests against the fixed
            # args.quantize_tolerance rather than the growing `tolerance`,
            # so it never becomes true and `tolerance` always climbs to the
            # cap of 10 — looks like a bug; confirm intent.
            while not np.all(
                np.absolute(ref_model_result - tf_result) <= args.quantize_tolerance
            ):
                tolerance = tolerance + 1
                if tolerance >= 10:
                    break

            msg = "Result is within {} {}".format(tolerance, test)
            print(msg)

            np.set_printoptions(threshold=128)
            print("tf_result: {}\n".format(tf_result.shape))
            print(tf_result)
            print("ref_model_result: {}\n".format(ref_model_result.shape))
            print(ref_model_result)
            # print(tf_result - ref_model_result)
            return (TestResult.MISMATCH, tolerance, msg)
    else:
        if np.allclose(
            ref_model_result, tf_result, atol=args.tolerance, equal_nan=True
        ):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
        else:
            print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))

            # Many of these tests would match with a reasonable looser tolerence.
            # Determine what would have worked.
            tolerance = args.tolerance * 10.0
            while not np.allclose(
                ref_model_result, tf_result, atol=tolerance, equal_nan=True
            ):
                tolerance = tolerance * 10.0
                if tolerance > 1.0e10:
                    tolerance = math.inf
                    break

            msg = "Result is within {:.0e} {}".format(tolerance, test_name)
            print(msg)

            np.set_printoptions(precision=4, threshold=128)
            print("tf_result: {}\n".format(tf_result.shape))
            print(tf_result)
            print("ref_model_result: {}\n".format(ref_model_result.shape))
            print(ref_model_result)
            # print(tf_result - ref_model_result)
            return (TestResult.MISMATCH, tolerance, msg)

    return (TestResult.PASS, args.tolerance, msg)
681
682
def worker_thread(task_queue, args, result_queue):
    """Worker loop: drain (test, framework) tasks and run each one.

    Pulls tasks non-blockingly from task_queue until it is empty (or a
    None sentinel is seen), runs run_test() on each, and pushes
    (test, framework, rc, tolerance, msg, duration) onto result_queue.
    Any exception escaping run_test is reported as INTERNAL_ERROR so one
    broken test cannot kill the worker.

    Returns:
        True when the queue has been drained.
    """
    while True:
        try:
            (test, framework) = task_queue.get(block=False)
        except queue.Empty:
            break

        if test is None:
            break

        msg = ""
        start_time = datetime.now()
        try:
            (rc, tolerance, msg) = run_test(args, test, framework)
        except Exception as e:
            print("Internal regression error: {}".format(e))
            # Positional arguments: the etype=/value=/tb= keywords were
            # removed from traceback.format_exception in Python 3.10 and
            # would raise TypeError here, masking the original error.
            print(
                "".join(
                    traceback.format_exception(type(e), e, e.__traceback__)
                )
            )
            rc = TestResult.INTERNAL_ERROR
            tolerance = 0.0

        end_time = datetime.now()

        result_queue.put((test, framework, rc, tolerance, msg, end_time - start_time))
        task_queue.task_done()

    return True
715
716
def getTestsInDir(directory):
    """Recursively collect test directories under *directory*.

    A directory counts as a test when it directly contains a test.json
    descriptor; recursion stops there. Non-directories yield nothing.
    Returns a (possibly empty) list of test directory paths.
    """
    if os.path.isfile(os.path.join(directory, "test.json")):
        return [directory]

    if not os.path.isdir(directory):
        return []

    found = []
    for entry in glob.glob(os.path.join(directory, "*")):
        found.extend(getTestsInDir(entry))
    return found
728
729
def main():
    """Script entry point.

    Parses arguments, queues every (test, framework) pair, fans them out
    to worker threads, then aggregates the results into a console summary
    and an xunit XML report.

    Returns:
        0 on success; 1 when compiler errors, reference model errors, or
        mismatches occurred and regression mode is off.
    """
    args = parse_args()

    set_print_in_color(not args.no_color)

    if args.output_file:
        set_print_in_color(False)
        sys.stdout = open(args.output_file, "w")

    # Disable TF info messages
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    task_queue = queue.Queue()
    result_queue = queue.Queue()

    threads = []

    # Result counters for each of the TestResult return codes
    results = [0] * len(TestResult)

    for tdir in args.test:

        if args.recursive_tests:
            tdirList = getTestsInDir(tdir)
        else:
            tdirList = [tdir]

        for t in tdirList:
            for f in args.framework:
                task_queue.put((t, f))

    for i in range(args.jobs):
        t = threading.Thread(
            target=worker_thread, args=(task_queue, args, result_queue)
        )
        # Thread.setDaemon() is deprecated since Python 3.10; set the
        # attribute directly instead.
        t.daemon = True
        t.start()
        threads.append(t)

    # Run until queue is empty
    task_queue.join()

    print_color(LogColors.BOLD_WHITE, "Result summary")

    result_list = []
    while True:
        try:
            test, framework, rc, tol, msg, time_delta = result_queue.get(block=False)
        except queue.Empty:
            break

        result_list.append((test, framework, rc, tol, msg, time_delta))
        results[rc] = results[rc] + 1

    xunit_result = xunit_results()
    xunit_suite = xunit_result.create_suite(args.xunit_classname_prefix)

    # Sort by test name
    for test, framework, rc, tol, err_msg, time_delta in sorted(
        result_list, key=lambda tup: tup[0]
    ):

        test_name = os.path.basename(test)
        class_name = f"{args.xunit_classname_prefix}.{framework}"

        xt = xunit_test(test_name, class_name)

        msg = TestResultErrorStr[rc]

        xt.time = str(
            float(time_delta.seconds) + (float(time_delta.microseconds) * 1e-6)
        )

        if len(msg) > 0:
            print("{} on {} {}".format(msg, framework, test))

        # Add any more verbose messaging for the xml log
        if err_msg:
            msg = "{} {}".format(msg, err_msg)

        if rc == TestResult.PASS:
            pass
        elif rc == TestResult.SKIPPED:
            xt.skipped()
        else:
            xt.failed(msg)

        xunit_suite.tests.append(xt)

        result_queue.task_done()

    xunit_result.write_results(args.xunit_file)

    print("Totals: ", end="")
    for result in TestResult:
        print("{} {}, ".format(results[result], result.name.lower()), end="")
    print()

    if not args.regression_mode and (
        results[TestResult.COMPILER_ERROR] > 0
        or results[TestResult.REF_MODEL_ERROR] > 0
        or results[TestResult.MISMATCH] > 0
    ):
        return 1

    return 0
836
837
if __name__ == "__main__":
    # sys.exit rather than the site-provided exit() builtin, which is meant
    # for interactive sessions and may be absent (e.g. under python -S).
    sys.exit(main())