blob: c55864ae48a2ffde4a8acf62cc3374b35b7c464c [file] [log] [blame]
Jeremy Johnson015c3552022-02-23 12:15:03 +00001#!/usr/bin/env python3
Luke Hutton261b7b62023-01-10 14:50:31 +00002# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson015c3552022-02-23 12:15:03 +00003# SPDX-License-Identifier: Apache-2.0
4import argparse
5import glob
6import json
7import math
8import os
9import queue
10import re
11import sys
12import threading
13import traceback
14from datetime import datetime
15from enum import IntEnum
16from enum import unique
17
18import numpy as np
19from checker.tosa_result_checker import LogColors
20from checker.tosa_result_checker import print_color
21from checker.tosa_result_checker import set_print_in_color
22from runner.run_command import run_sh_command
23from xunit.xunit import xunit_results
24from xunit.xunit import xunit_test
25
26
def parse_args():
    """Parse the command line for the framework compiler test runner.

    Post-processing applied on top of argparse:
      * ``--framework`` defaults to ``["tf", "tflite"]`` when not given, since an
        ``append`` action cannot both accumulate and override a default.
      * ``--jobs`` <= 0 means "one job per detected CPU".

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--test",
        dest="test",
        default=[],
        type=str,
        nargs="+",
        help="Test(s) to run",
    )
    parser.add_argument(
        "-r",
        "--recursive",
        dest="recursive_tests",
        action="store_true",
        help="Recursively search for tests",
    )
    parser.add_argument(
        "--tf-base-dir",
        dest="tf_base_dir",
        type=str,
        required=True,
        help="Tensorflow/MLIR base directory",
    )
    parser.add_argument(
        "--tools-base-dir",
        dest="tools_base_dir",
        type=str,
        required=True,
        help="Reference model base directory",
    )
    parser.add_argument(
        "-v", "--verbose", dest="verbose", action="count", help="Verbose run"
    )
    parser.add_argument(
        "-dref",
        "--debug-ref-model",
        dest="debug_ref_model",
        action="store_true",
        help="Enable TOSA Reference model debugging",
    )
    parser.add_argument(
        "--tolerance",
        dest="tolerance",
        default=1e-3,
        type=float,
        help="Comparison tolerance b value",
    )
    parser.add_argument(
        "--no-compiler",
        dest="no_compiler",
        action="store_true",
        help="Do not run TF MLIR/tfopt/TOSA compiler. Just run TOSA Reference model",
    )
    parser.add_argument(
        "--no-ref-model",
        dest="no_ref",
        action="store_true",
        help="Do not run TOSA reference model, just run TF MLIR/tfopt/TOSA compiler.",
    )
    parser.add_argument(
        "--valgrind",
        dest="valgrind",
        action="store_true",
        help="Enable valgrind on TOSA Reference Model",
    )
    parser.add_argument(
        "-j", "--jobs", dest="jobs", type=int, default=1, help="Number of parallel jobs"
    )
    parser.add_argument(
        "--no-color",
        "--no-colour",
        dest="no_color",
        action="store_true",
        help="Disable color output",
    )
    parser.add_argument(
        "-f",
        "--framework",
        dest="framework",
        default=[],
        action="append",
        help="Frameworks to test (tf, tflite)",
    )
    parser.add_argument(
        "--override-exclusions",
        dest="override_exclusions",
        default=False,
        action="store_true",
        help="Ignore the framework exclusions listed in the test JSON",
    )
    parser.add_argument(
        "--xunit-file",
        dest="xunit_file",
        type=str,
        default="result.xml",
        help="XUnit result output file",
    )
    parser.add_argument(
        "--xunit-classname-prefix",
        dest="xunit_classname_prefix",
        default="TFUnitTests",
        help="Prefix for xunit classname",
    )
    parser.add_argument(
        "--hex-bool-hack",
        dest="hex_bool_hack",
        default=1,
        type=int,
        help=(
            "Hack around bug in MLIR hex parsing for boolean types"
            " by disabling hex encoding"
        ),
    )
    parser.add_argument(
        "--regression-mode",
        dest="regression_mode",
        default=False,
        action="store_true",
        help="Options to make the script more friendly for jenkins regressions",
    )
    parser.add_argument(
        "--quantize-tolerance",
        dest="quantize_tolerance",
        default=0,
        type=int,
        help=(
            "Tolerance when comparing TOSA reference model result"
            " to TensorFlow Lite reference"
        ),
    )
    parser.add_argument(
        "--test-dir",
        dest="test_dir",
        default="",
        help="Path to prepend to paths in test.json",
    )

    parser.add_argument(
        "-o", "--output", dest="output_file", help="Redirect script output to a file"
    )

    args = parser.parse_args()

    # No easy way to both do array append and override a default value
    if not args.framework:
        args.framework = ["tf", "tflite"]

    # Autodetect CPU count. os.cpu_count() returns None when the count cannot be
    # determined; fall back to a single job instead of handing None to the
    # thread-spawning loop.
    if args.jobs <= 0:
        args.jobs = os.cpu_count() or 1

    return args
181
182
@unique
class TestResult(IntEnum):
    """Outcome of running one test on one framework.

    The integer values double as list indices (TestResultErrorStr and the
    per-result counters in main), so they are contiguous starting at 0.
    """

    PASS = 0
    COMPILER_ERROR = 1
    REF_MODEL_ERROR = 2
    REF_MODEL_UNPREDICTABLE = 3
    REF_MODEL_RUNTIME_ERROR = 4
    MISMATCH = 5
    NOT_LOWERED = 6
    INVALID_MLIR = 7
    INTERNAL_ERROR = 8
    SKIPPED = 9


# Human-readable description for each TestResult. Results without an entry
# (PASS, SKIPPED) map to "" so main prints no failure line for them.
_RESULT_DESCRIPTIONS = {
    TestResult.COMPILER_ERROR: "Compiler error",
    TestResult.REF_MODEL_ERROR: "Reference model error",
    TestResult.REF_MODEL_UNPREDICTABLE: "Reference model unpredictable",
    TestResult.REF_MODEL_RUNTIME_ERROR: "Reference model runtime error",
    TestResult.MISMATCH: "Mismatch",
    TestResult.NOT_LOWERED: "Not lowered",
    TestResult.INVALID_MLIR: "Invalid MLIR",
    TestResult.INTERNAL_ERROR: "Internal error",
}

# List indexed by TestResult value; built from the mapping so entries cannot
# drift out of order relative to the enum.
TestResultErrorStr = [_RESULT_DESCRIPTIONS.get(result, "") for result in TestResult]
208]
209
210
def parse_compiler_output(compiler_stdout, compiler_stderr):
    """Classify the TOSA legalization compiler's output.

    Only compiler_stdout is inspected; compiler_stderr is accepted so the
    signature mirrors parse_reference_model_output but is not used.

    Returns:
        TestResult.NOT_LOWERED if any stdout line reports that an op
        "has not been lowered yet, skipped", otherwise TestResult.PASS.
    """
    not_lowered = re.compile(".* has not been lowered yet, skipped.*")
    hit = any(not_lowered.match(text) for text in compiler_stdout.splitlines())
    return TestResult.NOT_LOWERED if hit else TestResult.PASS
220
221
def parse_reference_model_output(ref_model_stdout, ref_model_stderr):
    """Classify the TOSA reference model's stderr.

    Lines of ref_model_stderr are scanned in order; the first line matching
    any known status pattern decides the result. Within a single line the
    priority is UNPREDICTABLE, then "Graph result: ERROR", then
    "Unknown graph status code". ref_model_stdout is not inspected.

    Returns:
        The matching TestResult, or TestResult.PASS if no line matches.
    """
    status_patterns = (
        (re.compile(r".*UNPREDICTABLE.*"), TestResult.REF_MODEL_UNPREDICTABLE),
        (re.compile(".* Graph result: ERROR.*"), TestResult.REF_MODEL_ERROR),
        (
            re.compile(".* Unknown graph status code.*"),
            TestResult.REF_MODEL_RUNTIME_ERROR,
        ),
    )

    for text in ref_model_stderr.splitlines():
        for pattern, verdict in status_patterns:
            if pattern.match(text):
                return verdict

    return TestResult.PASS
237
238
def write_reference_runner_json(
    filename,
    tosa_filename,
    ifm_name,
    ifm_file,
    ofm_name,
    ofm_file,
    expected_failure=False,
):
    """Write a self-contained JSON test descriptor for the reference model.

    The descriptor makes it easy for third-party tools to pick up the test
    and build their own command lines.

    Args:
        filename: path of the JSON file to (over)write.
        tosa_filename: TOSA flatbuffer file name recorded under "tosa_file".
        ifm_name: input tensor names.
        ifm_file: input .npy file paths.
        ofm_name: output tensor names.
        ofm_file: output .npy file names.
        expected_failure: recorded verbatim under "expected_failure".
    """
    descriptor = {
        "tosa_file": tosa_filename,
        "ifm_name": ifm_name,
        "ifm_file": ifm_file,
        "ofm_name": ofm_name,
        "ofm_file": ofm_file,
        "expected_failure": expected_failure,
    }

    with open(filename, "w") as out:
        json.dump(descriptor, out, indent=" ")
262
263
264def run_test(args, test, framework):
265
266 # parse test_name from test directory path
267 test_path = test.split("/")
268 test_name = None
269 for t in test_path[::-1]:
270 if len(t) != 0:
271 test_name = t
272 break
273 if not test_name:
274 raise Exception("Could not parse test_name from {}".format(test))
275
276 print_color(LogColors.GREEN, "## Running {} test {}".format(framework, test_name))
277
278 msg = ""
279
280 try:
281 with open(os.path.join(test, "test.json"), "r") as f:
282 test_desc = json.load(f)
283 except Exception:
284 raise Exception(
285 "Could not load or parse test from {}".format(
286 os.path.join(test, "test.json")
287 )
288 )
289
290 try:
291 if not args.override_exclusions:
292 for excl in test_desc["framework_exclusions"]:
293 if excl == framework:
294 print_color(LogColors.GREEN, "Results SKIPPED")
295 return (TestResult.SKIPPED, 0.0, "")
296 except KeyError:
297 pass
298
299 tf_tools_dir = os.path.abspath(
300 "{}/bazel-bin/tensorflow/compiler/mlir".format(args.tf_base_dir)
301 )
302
303 pre_opt_filename = os.path.join(test, "test_{}.preopt.mlir".format(framework))
304 post_opt_filename = os.path.join(test, "test_{}.postopt.mlir".format(framework))
305 if args.test_dir:
306 test_path_prepend = args.test_dir
307 else:
308 test_path_prepend = test
309
310 # 1. Framework to MLIR translator command
311 if framework == "tf":
312 if test_desc["tf_model_filename"].endswith(".mlir"):
313 pre_opt_filename = test_desc["tf_model_filename"]
314 translate_mlir_cmd = []
315 else:
316 translate_mlir_cmd = [
317 os.path.join(tf_tools_dir, "tf-mlir-translate"),
318 "--graphdef-to-mlir",
319 "--tf-enable-shape-inference-on-import",
320 "--tf-output-arrays={}".format(test_desc["tf_result_name"]),
321 os.path.join(test_path_prepend, test_desc["tf_model_filename"]),
322 "-o",
323 pre_opt_filename,
324 ]
325 elif framework == "tflite":
326 if test_desc["tflite_model_filename"].endswith(".mlir"):
327 pre_opt_filename = test_desc["tflite_model_filename"]
328 translate_mlir_cmd = []
329 else:
330 translate_mlir_cmd = [
331 os.path.join(tf_tools_dir, "lite", "flatbuffer_translate"),
332 "--tflite-flatbuffer-to-mlir",
333 os.path.join(test_path_prepend, test_desc["tflite_model_filename"]),
334 "--output-arrays={}".format(test_desc["tflite_result_name"]),
335 "-o",
336 pre_opt_filename,
337 ]
338 else:
339 raise Exception("Unknown framwork: {}".format(framework))
340
341 # Any additional inputs to the translator?
342 input_tensor_prefix = "TosaInput_"
343 flatbuffer_dir = "flatbuffer-{}".format(framework)
344 mlir_opts = []
345
346 # Temporary hack: MLIR's new hex encoding of large tensors does not work for
347 # boolean types
348 # for TF hash 8e8041d594a888eb67eafa5cc62627d7e9ca8082
349 if test.endswith("_bool") and args.hex_bool_hack:
350 mlir_opts.append("--mlir-print-elementsattrs-with-hex-if-larger=-1")
351
352 try:
353 # specify input tensors if test is generated from .pb
354 if framework == "tf":
355 # Convert the shape to a mlir-friendly string
356 shapes = []
357 for curr_shape in test_desc["ifm_shape"]:
358 shape_str = ""
359 for dim in curr_shape:
360 shape_str = shape_str + str(dim) + ","
361 shapes.append(shape_str)
362
363 translate_mlir_cmd.extend(
364 ["--tf-input-arrays", ",".join(test_desc["ifm_name"])]
365 )
366 translate_mlir_cmd.extend(["--tf-input-shapes", ":".join(shapes)])
367
368 # Write the hard-coded placeholder input (reshaped as necesary) to
369 # the file that compiler specified.
370 reference_runner_ifm_name = []
371 for i in range(len(test_desc["ifm_file"])):
372
373 ifm_tensor_name = "{}{}".format(input_tensor_prefix, i)
374
375 assert test_desc["ifm_file"][i].endswith(".npy")
376 ifm_np = np.load(os.path.join(test, test_desc["ifm_file"][i]))
Jared Smolensb7af4612022-03-21 19:41:52 -0700377
378 # We sometimes encounter input shape/expected input shape mismatches
379 # due to a missing batch dimension on the input (e.g. a single 3D image).
380 #
381 # Make sure input numpy and input shape from descriptor match,
382 # expand_dims on the outer dimensions until the rank matches,
383 # then do the shape comparison.
384 while len(list(ifm_np.shape)) < len(test_desc["ifm_shape"][i]):
385 ifm_np = np.expand_dims(ifm_np, axis=0)
386
Jeremy Johnson015c3552022-02-23 12:15:03 +0000387 assert list(ifm_np.shape) == test_desc["ifm_shape"][i]
388
389 reference_runner_ifm_name.append(ifm_tensor_name)
390
391 except KeyError:
392 # No additional inputs. Ignore.
393 pass
394
395 tf_opt_cmd = [
396 os.path.join(tf_tools_dir, "tf-opt"),
397 "--tf-executor-to-functional-conversion",
398 "--verify-each",
399 pre_opt_filename,
400 "-o",
401 post_opt_filename,
402 ]
403
404 translate_mlir_cmd.extend(mlir_opts)
405 tf_opt_cmd.extend(mlir_opts)
406
407 compiler_cmd = [os.path.join(tf_tools_dir, "tf-opt")]
408
409 if framework == "tf":
410 compiler_cmd.append("--tf-to-tosa-pipeline")
411 elif framework == "tflite":
412 compiler_cmd.append("--tfl-to-tosa-pipeline")
413 compiler_cmd.append("--tosa-strip-quant-types")
414
415 tosa_mlir_filename = os.path.join(test, "output_{}.tosa.mlir".format(framework))
416
417 flatbuffer_dir_fullpath = os.path.join(test, flatbuffer_dir)
418
419 os.makedirs(flatbuffer_dir_fullpath, exist_ok=True)
420
421 compiler_cmd.extend(
422 [
423 "--verify-each",
424 post_opt_filename,
425 "-o",
426 tosa_mlir_filename,
427 "--tosa-serialize",
428 "--tosa-flatbuffer-filename={}".format(
429 os.path.join(flatbuffer_dir_fullpath, "{}.tosa".format(test_name))
430 ),
431 ]
432 )
433
434 if not args.no_compiler:
435 try:
436 if translate_mlir_cmd:
437 run_sh_command(translate_mlir_cmd, args.verbose, True)
438 if tf_opt_cmd:
439 run_sh_command(tf_opt_cmd, args.verbose, True)
440 except Exception as e:
441 print_color(
442 LogColors.RED, "Results INVALID_MLIR {}: {}".format(test_name, e)
443 )
444 return (TestResult.INVALID_MLIR, 0.0, e)
445
446 try:
447
448 compiler_stdout, compiler_stderr = run_sh_command(
449 compiler_cmd, args.verbose, True
450 )
451 compiler_rc = parse_compiler_output(compiler_stdout, compiler_stderr)
452 if compiler_rc == TestResult.NOT_LOWERED:
453 print_color(
454 LogColors.RED,
455 "Results NOT_LOWERED {}, framework {}".format(test_name, framework),
456 )
457 return (TestResult.NOT_LOWERED, 0.0, "")
458
459 pass
460
461 except Exception as e:
462 if "same scale constraint" in str(e):
463 print_color(
464 LogColors.RED, "Results INVALID_MLIR {}: {}".format(test_name, e)
465 )
466 return (TestResult.INVALID_MLIR, 0.0, e)
467 else:
468 print_color(
469 LogColors.RED, "Results COMPILER_ERROR {}: {}".format(test_name, e)
470 )
471 return (TestResult.COMPILER_ERROR, 0.0, e)
472
473 if framework == "tf":
474 try:
475 tf_result = np.load(os.path.join(test, test_desc["tf_result_npy_filename"]))
476 except KeyError:
477 assert 0, "fail to load tf result numpy"
478 elif framework == "tflite":
479 try:
480 tf_result = np.load(
481 os.path.join(test, test_desc["tflite_result_npy_filename"])
482 )
483 except KeyError:
484 assert 0, "fail to load tflite result numpy"
485
Luke Hutton261b7b62023-01-10 14:50:31 +0000486 # TOSA has no notion of complex datatypes, it represents complex values using two
487 # fp32 output tensors representing real and imaginary values. When legalizing
488 # complex operations from frameworks, these two output tensors are combined into
489 # a single tensor of shape [?, ..., ?, 2] whereby each inner pair of values
490 # represents the real and imaginary parts of a complex value. This is completed
491 # by inserting reshape and concatenate TOSA operations during the legalization to
492 # maintain a one-to-one correspondance with framework outputs, thus simplifying
493 # legalization. Here tf_result should also match this format before being
494 # compared to the ref model output.
495 if tf_result.dtype == np.complex64:
496 ifm_shape = tf_result.shape + (2,)
497 tf_result = tf_result.view(np.float32)
498 tf_result = tf_result.reshape(ifm_shape)
499
Jeremy Johnson015c3552022-02-23 12:15:03 +0000500 # Generate test descriptor per flatbuffer generation
501 # Input .npy will be shared across different frameworks
502 # Output .npy will be generated in its corresponding flatbuffer
503 reference_runner_ifm_file = [
504 os.path.join("..", ifm_file) for ifm_file in test_desc["ifm_file"]
505 ]
506
507 # Check if there's any operator in output graph.
508 empty_graph = True
509 with open(tosa_mlir_filename, "r") as f:
510 for line in f:
511 if re.search('"tosa.*"', line):
512 empty_graph = False
513
514 break
515
516 # Fast-forward input tensor to output tensor if TOSA graph is empty.
517 if empty_graph:
518 reference_runner_ofm_name = reference_runner_ifm_name
519 else:
520 reference_runner_ofm_name = ["TosaOutput_0"]
521
522 write_reference_runner_json(
523 filename=os.path.join(test, flatbuffer_dir, "desc.json"),
524 tosa_filename="{}.tosa".format(test_name),
525 ifm_name=reference_runner_ifm_name,
526 ifm_file=reference_runner_ifm_file,
527 ofm_name=reference_runner_ofm_name,
528 ofm_file=["ref_model_output_0.npy"],
529 )
530
531 ref_model_cmd = [
532 os.path.join(
533 args.tools_base_dir, "build", "reference_model", "tosa_reference_model"
534 ),
Eric Kunze286f8342022-06-22 11:30:23 -0700535 "--test_desc={}".format(os.path.join(test, flatbuffer_dir, "desc.json")),
Jeremy Johnson015c3552022-02-23 12:15:03 +0000536 ]
537
538 if args.debug_ref_model:
Eric Kunze286f8342022-06-22 11:30:23 -0700539 ref_model_cmd.extend(["-D ALL", "-l high"])
Jeremy Johnson015c3552022-02-23 12:15:03 +0000540
541 if args.valgrind:
542 ref_model_cmd = [
543 "valgrind",
544 "--show-leak-kinds=all",
545 "--log-fd=1",
546 "-q",
547 ] + ref_model_cmd
548
549 # Clean out any ref_model result first
550 try:
551 os.remove(os.path.join(test, flatbuffer_dir, "ref_model_*.npy"))
552 except FileNotFoundError:
553 pass
554
Jared Smolensb7af4612022-03-21 19:41:52 -0700555 if args.no_ref:
556 return (TestResult.PASS, 0.0, msg)
557
558 try:
559 ref_model_stdout, ref_model_stderr = run_sh_command(
560 ref_model_cmd, args.verbose, True
561 )
562 ref_model_rc = parse_reference_model_output(ref_model_stdout, ref_model_stderr)
563 if ref_model_rc != TestResult.PASS:
564 return (ref_model_rc, 0.0, "")
565 except Exception as e:
566 ref_model_rc = parse_reference_model_output("", str(e))
567 if ref_model_rc != TestResult.PASS:
Jeremy Johnson015c3552022-02-23 12:15:03 +0000568 print_color(
569 LogColors.RED,
Jared Smolensb7af4612022-03-21 19:41:52 -0700570 "Results {} {}: {}".format(
571 TestResultErrorStr[ref_model_rc], test_name, e
572 ),
Jeremy Johnson015c3552022-02-23 12:15:03 +0000573 )
Jared Smolensb7af4612022-03-21 19:41:52 -0700574 return (ref_model_rc, 0.0, "")
575 print_color(
576 LogColors.RED,
577 "Results REF_MODEL_RUNTIME_ERROR {}: {}".format(test_name, e),
578 )
579 return (TestResult.REF_MODEL_RUNTIME_ERROR, 0.0, e)
Jeremy Johnson015c3552022-02-23 12:15:03 +0000580
581 if tf_result.dtype == np.float16:
582 tf_result = tf_result.astype(np.float32)
583 elif (
584 tf_result.dtype == np.uint8
585 or tf_result.dtype == np.int8
586 or tf_result.dtype == np.int16
587 or tf_result.dtype == np.int64
588 ):
589 tf_result = tf_result.astype(np.int32)
590
591 # For now, search for the output from ref_model
592 ref_model_result_files = glob.glob(
593 os.path.join(test, flatbuffer_dir, "ref_model_*.npy")
594 )
595 ref_model_result = np.load(ref_model_result_files[0])
596
597 assert (
598 tf_result.dtype == ref_model_result.dtype
599 ), "Numpy type mismatch {} != {} when comparing result".format(
600 tf_result.dtype, ref_model_result.dtype
601 )
602
603 # Size comparison
604 # Size = 1 tensors can be equivalently represented as having rank 0 or rank
605 # >= 0, allow that special case
606 tf_result = np.squeeze(tf_result)
607 ref_model_result = np.squeeze(ref_model_result)
608
609 if np.shape(tf_result) != np.shape(ref_model_result):
610 print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
611 msg = "Shapes mismatch: Reference {} vs {}".format(
612 np.shape(tf_result), np.shape(ref_model_result)
613 )
614 print(msg)
615 return (TestResult.MISMATCH, 0.0, msg)
616
617 # for quantized test, allow +-(args.quantize_tolerance) error
618 if ref_model_result.dtype == np.int32:
619 assert tf_result.dtype == np.int32
620
621 if np.all(np.absolute(ref_model_result - tf_result) <= args.quantize_tolerance):
622 print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
623 else:
624 print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
625
626 tolerance = args.quantize_tolerance + 1
627 while not np.all(
628 np.absolute(ref_model_result - tf_result) <= args.quantize_tolerance
629 ):
630 tolerance = tolerance + 1
631 if tolerance >= 10:
632 break
633
634 msg = "Result is within {} {}".format(tolerance, test)
635 print(msg)
636
637 np.set_printoptions(threshold=128)
638 print("tf_result: {}\n".format(tf_result.shape))
639 print(tf_result)
640 print("ref_model_result: {}\n".format(ref_model_result.shape))
641 print(ref_model_result)
642 # print(tf_result - ref_model_result)
643 return (TestResult.MISMATCH, tolerance, msg)
644 else:
645 if np.allclose(
646 ref_model_result, tf_result, atol=args.tolerance, equal_nan=True
647 ):
648 print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
649 else:
650 print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
651
652 # Many of these tests would match with a reasonable looser tolerence.
653 # Determine what would have worked.
654 tolerance = args.tolerance * 10.0
655 while not np.allclose(
656 ref_model_result, tf_result, atol=tolerance, equal_nan=True
657 ):
658 tolerance = tolerance * 10.0
659 if tolerance > 1.0e10:
660 tolerance = math.inf
661 break
662
663 msg = "Result is within {:.0e} {}".format(tolerance, test_name)
664 print(msg)
665
666 np.set_printoptions(precision=4, threshold=128)
667 print("tf_result: {}\n".format(tf_result.shape))
668 print(tf_result)
669 print("ref_model_result: {}\n".format(ref_model_result.shape))
670 print(ref_model_result)
671 # print(tf_result - ref_model_result)
672 return (TestResult.MISMATCH, tolerance, msg)
673
674 return (TestResult.PASS, args.tolerance, msg)
675
676
def worker_thread(task_queue, args, result_queue):
    """Run queued (test, framework) tasks until the queue is drained.

    Each task is executed with run_test(). The outcome is pushed to
    result_queue as ``(test, framework, rc, tolerance, msg, elapsed)``, and the
    task is marked done so that task_queue.join() in main can return.
    Exceptions escaping run_test() are reported as TestResult.INTERNAL_ERROR.

    Returns:
        True once there is no more work for this worker.
    """
    while True:
        try:
            (test, framework) = task_queue.get(block=False)
        except queue.Empty:
            break

        if test is None:
            # A None test ends this worker; still acknowledge the task so
            # task_queue.join() does not wait on it forever.
            task_queue.task_done()
            break

        msg = ""
        start_time = datetime.now()
        try:
            (rc, tolerance, msg) = run_test(args, test, framework)
        except Exception as e:
            print("Internal regression error: {}".format(e))
            # Positional form works on all Python 3 versions; the `etype=`
            # keyword was removed in 3.10 and raised TypeError here, killing the
            # worker before task_done() and deadlocking task_queue.join().
            print("".join(traceback.format_exception(type(e), e, e.__traceback__)))
            rc = TestResult.INTERNAL_ERROR
            tolerance = 0.0
            # Surface the failure text in the xunit report.
            msg = str(e)

        end_time = datetime.now()

        result_queue.put((test, framework, rc, tolerance, msg, end_time - start_time))
        task_queue.task_done()

    return True
709
710
def getTestsInDir(directory):
    """Recursively collect test directories under ``directory``.

    A directory containing a ``test.json`` is itself a test and is not
    descended into further. Non-existent paths and plain files yield nothing.

    Returns:
        List of test directory paths, in glob order.
    """
    if os.path.isfile(os.path.join(directory, "test.json")):
        return [directory]

    if not os.path.isdir(directory):
        return []

    return [
        found
        for entry in glob.glob(os.path.join(directory, "*"))
        for found in getTestsInDir(entry)
    ]
722
723
def main():
    """Run every requested (test, framework) pair and report the results.

    Tests are distributed across ``args.jobs`` daemon worker threads. Results
    are tallied per TestResult, printed as a summary, and written to the xUnit
    file.

    Returns:
        1 if any COMPILER_ERROR, REF_MODEL_ERROR or MISMATCH occurred and
        --regression-mode is off, otherwise 0.
    """
    args = parse_args()

    set_print_in_color(not args.no_color)

    if args.output_file:
        set_print_in_color(False)
        sys.stdout = open(args.output_file, "w")

    # Disable TF info messages
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    task_queue = queue.Queue()
    result_queue = queue.Queue()

    threads = []

    # Result counters for each of the TestResult return codes
    results = [0] * len(TestResult)

    for tdir in args.test:
        if args.recursive_tests:
            tdirList = getTestsInDir(tdir)
        else:
            tdirList = [tdir]

        for t in tdirList:
            for f in args.framework:
                task_queue.put((t, f))

    for _ in range(args.jobs):
        t = threading.Thread(
            target=worker_thread, args=(task_queue, args, result_queue)
        )
        # Thread.setDaemon() is deprecated since Python 3.10; set the attribute.
        t.daemon = True
        t.start()
        threads.append(t)

    # Run until queue is empty
    task_queue.join()

    print_color(LogColors.BOLD_WHITE, "Result summary")

    result_list = []
    while True:
        try:
            test, framework, rc, tol, msg, time_delta = result_queue.get(block=False)
        except queue.Empty:
            break

        result_list.append((test, framework, rc, tol, msg, time_delta))
        results[rc] = results[rc] + 1

    xunit_result = xunit_results()
    xunit_suite = xunit_result.create_suite(args.xunit_classname_prefix)

    # Sort by test name
    for test, framework, rc, tol, err_msg, time_delta in sorted(
        result_list, key=lambda tup: tup[0]
    ):
        test_name = os.path.basename(test)
        class_name = f"{args.xunit_classname_prefix}.{framework}"

        xt = xunit_test(test_name, class_name)

        msg = TestResultErrorStr[rc]

        # total_seconds() also accounts for timedelta.days.
        xt.time = str(time_delta.total_seconds())

        if len(msg) > 0:
            print("{} on {} {}".format(msg, framework, test))

        # Add any more verbose messaging for the xml log
        if err_msg:
            msg = "{} {}".format(msg, err_msg)

        if rc == TestResult.PASS:
            pass
        elif rc == TestResult.SKIPPED:
            xt.skipped()
        else:
            xt.failed(msg)

        xunit_suite.tests.append(xt)

        result_queue.task_done()

    xunit_result.write_results(args.xunit_file)

    print("Totals: ", end="")
    for result in TestResult:
        print("{} {}, ".format(results[result], result.name.lower()), end="")
    print()

    if not args.regression_mode and (
        results[TestResult.COMPILER_ERROR] > 0
        or results[TestResult.REF_MODEL_ERROR] > 0
        or results[TestResult.MISMATCH] > 0
    ):
        return 1

    return 0
830
831
if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is absent under `python -S`.
    sys.exit(main())