#!/usr/bin/env python3
# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
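#
# Example invocation (script name and paths are illustrative):
#   python3 run_tf_unit_tests.py --tf-base-dir ~/src/tensorflow \
#       --tools-base-dir ~/src/reference_model -t tests/my_test -f tflite -j 8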
import argparse
import glob
import json
import math
import os
import queue
import re
import sys
import threading
import traceback
from datetime import datetime
from enum import IntEnum
from enum import unique

import numpy as np
from checker.tosa_result_checker import LogColors
from checker.tosa_result_checker import print_color
from checker.tosa_result_checker import set_print_in_color
from runner.run_command import run_sh_command
from xunit.xunit import xunit_results
from xunit.xunit import xunit_test


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--test",
        dest="test",
        default=[],
        type=str,
        nargs="+",
        help="Test(s) to run",
    )
    parser.add_argument(
        "-r",
        "--recursive",
        dest="recursive_tests",
        action="store_true",
        help="Recursively search for tests",
    )
    parser.add_argument(
        "--tf-base-dir",
        dest="tf_base_dir",
        type=str,
        required=True,
        help="Tensorflow/MLIR base directory",
    )
    parser.add_argument(
        "--tools-base-dir",
        dest="tools_base_dir",
        type=str,
        required=True,
        help="Reference model base directory",
    )
    parser.add_argument(
        "-v", "--verbose", dest="verbose", action="count", help="Verbose run"
    )
    parser.add_argument(
        "-dref",
        "--debug-ref-model",
        dest="debug_ref_model",
        action="store_true",
        help="Enable TOSA Reference model debugging",
    )
    parser.add_argument(
        "--tolerance",
        dest="tolerance",
        default=1e-3,
        type=float,
74 help="Comparison tolerance b value",
    )
    parser.add_argument(
        "--tosa_level",
        dest="tosa_level",
        default="EIGHTK",
        type=str,
81 help="A TOSA level defines operator parameter ranges that an implementation shall support."
82 "Config tosa_level for running the reference model only. Default is EIGHTK",
    )
    parser.add_argument(
        "--no-compiler",
        dest="no_compiler",
        action="store_true",
        help="Do not run TF MLIR/tfopt/TOSA compiler. Just run TOSA Reference model",
    )
    parser.add_argument(
        "--no-ref-model",
        dest="no_ref",
        action="store_true",
        help="Do not run TOSA reference model, just run TF MLIR/tfopt/TOSA compiler.",
    )
    parser.add_argument(
        "--valgrind",
        dest="valgrind",
        action="store_true",
        help="Enable valgrind on TOSA Reference Model",
    )
    parser.add_argument(
        "-j", "--jobs", dest="jobs", type=int, default=1, help="Number of parallel jobs"
    )
    parser.add_argument(
        "--no-color",
        "--no-colour",
        dest="no_color",
        action="store_true",
        help="Disable color output",
    )
    parser.add_argument(
        "-f",
        "--framework",
        dest="framework",
        default=[],
        action="append",
        help="Frameworks to test (tf, tflite)",
    )
    parser.add_argument(
        "--override-exclusions",
        dest="override_exclusions",
        default=False,
        action="store_true",
        help="Ignore the framework exclusions listed in the test JSON",
    )
    parser.add_argument(
        "--xunit-file",
        dest="xunit_file",
        type=str,
        default="result.xml",
        help="XUnit result output file",
    )
    parser.add_argument(
        "--xunit-classname-prefix",
        dest="xunit_classname_prefix",
        default="TFUnitTests",
        help="Prefix for xunit classname",
    )
    parser.add_argument(
        "--hex-bool-hack",
        dest="hex_bool_hack",
        default=1,
        type=int,
        help=(
            "Hack around bug in MLIR hex parsing for boolean types"
            " by disabling hex encoding"
        ),
    )
    parser.add_argument(
        "--regression-mode",
        dest="regression_mode",
        default=False,
        action="store_true",
155 help="Options to make the script more friendly for jenkins regressions",
    )
    parser.add_argument(
        "--quantize-tolerance",
        dest="quantize_tolerance",
        default=0,
        type=int,
        help=(
            "Tolerance when comparing TOSA reference model result"
            " to TensorFlow Lite reference"
        ),
    )
    parser.add_argument(
        "--test-dir",
        dest="test_dir",
        default="",
        help="Path to prepend to paths in test.json",
    )

    parser.add_argument(
        "-o", "--output", dest="output_file", help="Redirect script output to a file"
    )

    args = parser.parse_args()

    # No easy way to both do array append and override a default value
    if not args.framework:
        args.framework = ["tf", "tflite"]

    # Autodetect CPU count
    if args.jobs <= 0:
        args.jobs = os.cpu_count()

    return args


@unique
class TestResult(IntEnum):
    PASS = 0
    COMPILER_ERROR = 1
    REF_MODEL_ERROR = 2
    REF_MODEL_UNPREDICTABLE = 3
    REF_MODEL_RUNTIME_ERROR = 4
    MISMATCH = 5
    NOT_LOWERED = 6
    INVALID_MLIR = 7
    INTERNAL_ERROR = 8
    SKIPPED = 9


TestResultErrorStr = [
    "",
    "Compiler error",
    "Reference model error",
    "Reference model unpredictable",
    "Reference model runtime error",
    "Mismatch",
    "Not lowered",
    "Invalid MLIR",
    "Internal error",
    "",
]


def parse_compiler_output(compiler_stdout, compiler_stderr):
    # Look for "has not been lowered yet, skipped" strings in stdout
    expr = re.compile(".* has not been lowered yet, skipped.*")

    for line in compiler_stdout.splitlines():
        if expr.match(line):
            return TestResult.NOT_LOWERED

    return TestResult.PASS


def parse_reference_model_output(ref_model_stdout, ref_model_stderr):
    # Look for UNPREDICTABLE, ERROR, or unknown-status markers in stderr
    unpredictable_expr = re.compile(r".*UNPREDICTABLE.*")
    error_expr = re.compile(".* Graph result: ERROR.*")
    unknown_expr = re.compile(".* Unknown graph status code.*")

    for line in ref_model_stderr.splitlines():
        if unpredictable_expr.match(line):
            return TestResult.REF_MODEL_UNPREDICTABLE
        elif error_expr.match(line):
            return TestResult.REF_MODEL_ERROR
        elif unknown_expr.match(line):
            return TestResult.REF_MODEL_RUNTIME_ERROR

    return TestResult.PASS


# write a self-contained test descriptor in json format
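# An illustrative desc.json (values are examples only):
# {
#     "tosa_file": "my_test.tosa",
#     "ifm_name": ["TosaInput_0"],
#     "ifm_file": ["../placeholder_0.npy"],
#     "ofm_name": ["TosaOutput_0"],
#     "ofm_file": ["ref_model_output_0.npy"],
#     "expected_failure": false
# }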
def write_reference_runner_json(
    filename,
    tosa_filename,
    ifm_name,
    ifm_file,
    ofm_name,
    ofm_file,
    expected_failure=False,
):
257 """Write a json test file so that it is fairly easy to pick up the test
258 and generate commands for third party tool"""
    test_desc = dict()

    test_desc["tosa_file"] = tosa_filename
    test_desc["ifm_name"] = ifm_name
    test_desc["ifm_file"] = ifm_file
    test_desc["ofm_name"] = ofm_name
    test_desc["ofm_file"] = ofm_file
    test_desc["expected_failure"] = expected_failure

    with open(filename, "w") as f:
        json.dump(test_desc, f, indent="  ")


def run_test(args, test, framework):

    # parse test_name from test directory path
    test_path = test.split("/")
    test_name = None
    for t in test_path[::-1]:
        if len(t) != 0:
            test_name = t
            break
    if not test_name:
        raise Exception("Could not parse test_name from {}".format(test))

    print_color(LogColors.GREEN, "## Running {} test {}".format(framework, test_name))

    msg = ""

    try:
        with open(os.path.join(test, "test.json"), "r") as f:
            test_desc = json.load(f)
    except Exception:
        raise Exception(
            "Could not load or parse test from {}".format(
                os.path.join(test, "test.json")
            )
        )

    try:
        if not args.override_exclusions:
            for excl in test_desc["framework_exclusions"]:
                if excl == framework:
                    print_color(LogColors.GREEN, "Results SKIPPED")
                    return (TestResult.SKIPPED, 0.0, "")
    except KeyError:
        pass

    tf_tools_dir = os.path.abspath(
        "{}/bazel-bin/tensorflow/compiler/mlir".format(args.tf_base_dir)
    )

    pre_opt_filename = os.path.join(test, "test_{}.preopt.mlir".format(framework))
    post_opt_filename = os.path.join(test, "test_{}.postopt.mlir".format(framework))
    if args.test_dir:
        test_path_prepend = args.test_dir
    else:
        test_path_prepend = test

    # 1. Framework to MLIR translator command
    if framework == "tf":
        if test_desc["tf_model_filename"].endswith(".mlir"):
            pre_opt_filename = test_desc["tf_model_filename"]
            translate_mlir_cmd = []
        else:
            translate_mlir_cmd = [
                os.path.join(tf_tools_dir, "tf-mlir-translate"),
                "--graphdef-to-mlir",
                "--tf-enable-shape-inference-on-import",
                "--tf-output-arrays={}".format(test_desc["tf_result_name"]),
                os.path.join(test_path_prepend, test_desc["tf_model_filename"]),
                "-o",
                pre_opt_filename,
            ]
    elif framework == "tflite":
        if test_desc["tflite_model_filename"].endswith(".mlir"):
            pre_opt_filename = test_desc["tflite_model_filename"]
            translate_mlir_cmd = []
        else:
            translate_mlir_cmd = [
                os.path.join(tf_tools_dir, "lite", "flatbuffer_translate"),
                "--tflite-flatbuffer-to-mlir",
                os.path.join(test_path_prepend, test_desc["tflite_model_filename"]),
                "--output-arrays={}".format(test_desc["tflite_result_name"]),
                "-o",
                pre_opt_filename,
            ]
    else:
        raise Exception("Unknown framework: {}".format(framework))

    # Any additional inputs to the translator?
    input_tensor_prefix = "TosaInput_"
    flatbuffer_dir = "flatbuffer-{}".format(framework)
    mlir_opts = []

    # Temporary hack: MLIR's new hex encoding of large tensors does not work for
    # boolean types
    # for TF hash 8e8041d594a888eb67eafa5cc62627d7e9ca8082
    if test.endswith("_bool") and args.hex_bool_hack:
        mlir_opts.append("--mlir-print-elementsattrs-with-hex-if-larger=-1")

    try:
        # specify input tensors if test is generated from .pb
        if framework == "tf":
            # Convert the shape to an MLIR-friendly string
            shapes = []
            for curr_shape in test_desc["ifm_shape"]:
                shape_str = ""
                for dim in curr_shape:
                    shape_str = shape_str + str(dim) + ","
                shapes.append(shape_str)

            translate_mlir_cmd.extend(
                ["--tf-input-arrays", ",".join(test_desc["ifm_name"])]
            )
            translate_mlir_cmd.extend(["--tf-input-shapes", ":".join(shapes)])

        # Write the hard-coded placeholder input (reshaped as necessary) to
        # the file that the compiler specified.
        reference_runner_ifm_name = []
        for i in range(len(test_desc["ifm_file"])):

            ifm_tensor_name = "{}{}".format(input_tensor_prefix, i)

            assert test_desc["ifm_file"][i].endswith(".npy")
            ifm_np = np.load(os.path.join(test, test_desc["ifm_file"][i]))

            # We sometimes encounter input shape/expected input shape mismatches
            # due to a missing batch dimension on the input (e.g. a single 3D image).
            #
            # Make sure input numpy and input shape from descriptor match,
            # expand_dims on the outer dimensions until the rank matches,
            # then do the shape comparison.
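            # e.g. a single 3D image of shape (H, W, C) is promoted to (1, H, W, C).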
            while len(list(ifm_np.shape)) < len(test_desc["ifm_shape"][i]):
                ifm_np = np.expand_dims(ifm_np, axis=0)

            # After legalization, complex tensors are expected to be represented
            # as a single floating point tensor of shape [?, ..., ?, 2].
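            # e.g. a complex64 input of shape (8,) arrives as a float32 array of shape (8, 2).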
            expected_shape = test_desc["ifm_shape"][i]
            if test.endswith("c64"):
                expected_shape.append(2)

            assert list(ifm_np.shape) == expected_shape

            reference_runner_ifm_name.append(ifm_tensor_name)

    except KeyError:
        # No additional inputs. Ignore.
        pass

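    # 2. Run tf-opt to functionalize the TF executor dialect graph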
    tf_opt_cmd = [
        os.path.join(tf_tools_dir, "tf-opt"),
        "--tf-executor-to-functional-conversion",
        "--verify-each",
        pre_opt_filename,
        "-o",
        post_opt_filename,
    ]

    translate_mlir_cmd.extend(mlir_opts)
    tf_opt_cmd.extend(mlir_opts)

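    # 3. Legalize the functionalized MLIR to the TOSA dialect and serialize a TOSA flatbuffer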
    compiler_cmd = [os.path.join(tf_tools_dir, "tf-opt")]

    if framework == "tf":
        compiler_cmd.append("--tf-to-tosa-pipeline")
    elif framework == "tflite":
        compiler_cmd.append("--tfl-to-tosa-pipeline")
        compiler_cmd.append("--tosa-strip-quant-types")

    tosa_mlir_filename = os.path.join(test, "output_{}.tosa.mlir".format(framework))

    flatbuffer_dir_fullpath = os.path.join(test, flatbuffer_dir)

    os.makedirs(flatbuffer_dir_fullpath, exist_ok=True)

    compiler_cmd.extend(
        [
            "--verify-each",
            post_opt_filename,
            "-o",
            tosa_mlir_filename,
            "--tosa-serialize",
            "--tosa-flatbuffer-filename={}".format(
                os.path.join(flatbuffer_dir_fullpath, "{}.tosa".format(test_name))
            ),
        ]
    )

    if not args.no_compiler:
        try:
            if translate_mlir_cmd:
                run_sh_command(translate_mlir_cmd, args.verbose, True)
            if tf_opt_cmd:
                run_sh_command(tf_opt_cmd, args.verbose, True)
        except Exception as e:
            print_color(
                LogColors.RED, "Results INVALID_MLIR {}: {}".format(test_name, e)
            )
            return (TestResult.INVALID_MLIR, 0.0, e)

        try:
            compiler_stdout, compiler_stderr = run_sh_command(
                compiler_cmd, args.verbose, True
            )
            compiler_rc = parse_compiler_output(compiler_stdout, compiler_stderr)
            if compiler_rc == TestResult.NOT_LOWERED:
                print_color(
                    LogColors.RED,
                    "Results NOT_LOWERED {}, framework {}".format(test_name, framework),
                )
                return (TestResult.NOT_LOWERED, 0.0, "")
        except Exception as e:
            if "same scale constraint" in str(e):
                print_color(
                    LogColors.RED, "Results INVALID_MLIR {}: {}".format(test_name, e)
                )
                return (TestResult.INVALID_MLIR, 0.0, e)
            else:
                print_color(
                    LogColors.RED, "Results COMPILER_ERROR {}: {}".format(test_name, e)
                )
                return (TestResult.COMPILER_ERROR, 0.0, e)

    if framework == "tf":
        try:
            tf_result = np.load(os.path.join(test, test_desc["tf_result_npy_filename"]))
        except KeyError:
            assert 0, "failed to load tf result numpy"
    elif framework == "tflite":
        try:
            tf_result = np.load(
                os.path.join(test, test_desc["tflite_result_npy_filename"])
            )
        except KeyError:
            assert 0, "failed to load tflite result numpy"

    # TOSA has no notion of complex datatypes; it represents complex values using two
    # fp32 output tensors for the real and imaginary values. When legalizing
    # complex operations from frameworks, these two output tensors are combined into
    # a single tensor of shape [?, ..., ?, 2] whereby each inner pair of values
    # represents the real and imaginary parts of a complex value. This is done
    # by inserting reshape and concatenate TOSA operations during the legalization to
    # maintain a one-to-one correspondence with framework outputs, thus simplifying
    # legalization. Here tf_result should also match this format before being
    # compared to the ref model output.
    if tf_result.dtype == np.complex64:
        result_shape = tf_result.shape + (2,)
        tf_result = tf_result.view(np.float32)
        tf_result = tf_result.reshape(result_shape)
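        # e.g. a complex64 result of shape (4, 8) becomes a float32 tensor of shape (4, 8, 2).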

    # Generate test descriptor per flatbuffer generation
    # Input .npy will be shared across different frameworks
    # Output .npy will be generated in its corresponding flatbuffer
    reference_runner_ifm_file = [
        os.path.join("..", ifm_file) for ifm_file in test_desc["ifm_file"]
    ]

    # Check if there's any operator in output graph.
    empty_graph = True
    with open(tosa_mlir_filename, "r") as f:
        for line in f:
            if re.search('"tosa.*"', line):
                empty_graph = False
                break

    # Fast-forward input tensor to output tensor if TOSA graph is empty.
    if empty_graph:
        reference_runner_ofm_name = reference_runner_ifm_name
    else:
        reference_runner_ofm_name = ["TosaOutput_0"]

    write_reference_runner_json(
        filename=os.path.join(test, flatbuffer_dir, "desc.json"),
        tosa_filename="{}.tosa".format(test_name),
        ifm_name=reference_runner_ifm_name,
        ifm_file=reference_runner_ifm_file,
        ofm_name=reference_runner_ofm_name,
        ofm_file=["ref_model_output_0.npy"],
    )

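    # 4. Run the serialized flatbuffer through the TOSA reference model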
    ref_model_cmd = [
        os.path.join(
            args.tools_base_dir, "build", "reference_model", "tosa_reference_model"
        ),
        "--test_desc={}".format(os.path.join(test, flatbuffer_dir, "desc.json")),
    ]

    if args.debug_ref_model:
        ref_model_cmd.extend(["-D ALL", "-l high"])

    if args.valgrind:
        ref_model_cmd = [
            "valgrind",
            "--show-leak-kinds=all",
            "--log-fd=1",
            "-q",
        ] + ref_model_cmd

    ref_model_cmd = ref_model_cmd + ["--tosa_level={}".format(args.tosa_level)]

    # Clean out any ref_model results first (glob the pattern explicitly;
    # os.remove does not expand wildcards)
    for old_result in glob.glob(os.path.join(test, flatbuffer_dir, "ref_model_*.npy")):
        os.remove(old_result)

    if args.no_ref:
        return (TestResult.PASS, 0.0, msg)

    try:
        ref_model_stdout, ref_model_stderr = run_sh_command(
            ref_model_cmd, args.verbose, True
        )
        ref_model_rc = parse_reference_model_output(ref_model_stdout, ref_model_stderr)
        if ref_model_rc != TestResult.PASS:
            return (ref_model_rc, 0.0, "")
    except Exception as e:
        ref_model_rc = parse_reference_model_output("", str(e))
        if ref_model_rc != TestResult.PASS:
            print_color(
                LogColors.RED,
                "Results {} {}: {}".format(
                    TestResultErrorStr[ref_model_rc], test_name, e
                ),
            )
            return (ref_model_rc, 0.0, "")
        print_color(
            LogColors.RED,
            "Results REF_MODEL_RUNTIME_ERROR {}: {}".format(test_name, e),
        )
        return (TestResult.REF_MODEL_RUNTIME_ERROR, 0.0, e)

    if tf_result.dtype == np.float16:
        tf_result = tf_result.astype(np.float32)
    elif (
        tf_result.dtype == np.uint8
        or tf_result.dtype == np.int8
        or tf_result.dtype == np.int16
        or tf_result.dtype == np.int64
    ):
        tf_result = tf_result.astype(np.int32)

    # For now, search for the output from ref_model
    ref_model_result_files = glob.glob(
        os.path.join(test, flatbuffer_dir, "ref_model_*.npy")
    )
    ref_model_result = np.load(ref_model_result_files[0])

    assert (
        tf_result.dtype == ref_model_result.dtype
    ), "Numpy type mismatch {} != {} when comparing result".format(
        tf_result.dtype, ref_model_result.dtype
    )

    # Size comparison
    # Size = 1 tensors can be equivalently represented as having rank 0 or
    # rank >= 1, so allow that special case
    tf_result = np.squeeze(tf_result)
    ref_model_result = np.squeeze(ref_model_result)
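    # e.g. results of shape (1, 5) and (5,) both squeeze to (5,) and compare as equal.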

    if np.shape(tf_result) != np.shape(ref_model_result):
        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(tf_result), np.shape(ref_model_result)
        )
        print(msg)
        return (TestResult.MISMATCH, 0.0, msg)

    # for quantized test, allow +-(args.quantize_tolerance) error
    if ref_model_result.dtype == np.int32:
        assert tf_result.dtype == np.int32

        if np.all(np.absolute(ref_model_result - tf_result) <= args.quantize_tolerance):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
        else:
            print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))

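            # Search upward for the smallest integer tolerance (capped at 10) that would pass.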
            tolerance = args.quantize_tolerance + 1
            while not np.all(
                np.absolute(ref_model_result - tf_result) <= tolerance
            ):
                tolerance = tolerance + 1
                if tolerance >= 10:
                    break

            msg = "Result is within {} {}".format(tolerance, test_name)
            print(msg)

            np.set_printoptions(threshold=128)
            print("tf_result: {}\n".format(tf_result.shape))
            print(tf_result)
            print("ref_model_result: {}\n".format(ref_model_result.shape))
            print(ref_model_result)
            # print(tf_result - ref_model_result)
            return (TestResult.MISMATCH, tolerance, msg)
    else:
        if np.allclose(
            ref_model_result, tf_result, atol=args.tolerance, equal_nan=True
        ):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
        else:
            print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))

            # Many of these tests would match with a reasonably looser tolerance.
            # Determine what would have worked.
            tolerance = args.tolerance * 10.0
            while not np.allclose(
                ref_model_result, tf_result, atol=tolerance, equal_nan=True
            ):
                tolerance = tolerance * 10.0
                if tolerance > 1.0e10:
                    tolerance = math.inf
                    break

            msg = "Result is within {:.0e} {}".format(tolerance, test_name)
            print(msg)

            np.set_printoptions(precision=4, threshold=128)
            print("tf_result: {}\n".format(tf_result.shape))
            print(tf_result)
            print("ref_model_result: {}\n".format(ref_model_result.shape))
            print(ref_model_result)
            # print(tf_result - ref_model_result)
            return (TestResult.MISMATCH, tolerance, msg)

    return (TestResult.PASS, args.tolerance, msg)


def worker_thread(task_queue, args, result_queue):
    while True:
        try:
            (test, framework) = task_queue.get(block=False)
        except queue.Empty:
            break

        if test is None:
            break

        msg = ""
        start_time = datetime.now()
        try:
            (rc, tolerance, msg) = run_test(args, test, framework)
        except Exception as e:
            print("Internal regression error: {}".format(e))
            print(
                "".join(
                    traceback.format_exception(
                        type(e), e, e.__traceback__
                    )
                )
            )
            rc = TestResult.INTERNAL_ERROR
            tolerance = 0.0

        end_time = datetime.now()

        result_queue.put((test, framework, rc, tolerance, msg, end_time - start_time))
        task_queue.task_done()

    return True


726
727def getTestsInDir(directory):
728 # Recursively find any tests in this directory
729 if os.path.isfile(os.path.join(directory, "test.json")):
730 return [directory]
731 elif os.path.isdir(directory):
732 test_list = []
733 for d in glob.glob(os.path.join(directory, "*")):
734 test_list.extend(getTestsInDir(d))
735 return test_list
736 else:
737 return []
738
739
740def main():
741 args = parse_args()
742
743 set_print_in_color(not args.no_color)
744
745 if args.output_file:
746 set_print_in_color(False)
747 sys.stdout = open(args.output_file, "w")
748
749 # Disable TF info messages
750 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
751
752 task_queue = queue.Queue()
753 result_queue = queue.Queue()
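    # Tests are fanned out to worker threads via task_queue; results come back on result_queue.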

    threads = []

    # Result counters for each of the TestResult return codes
    results = [0] * len(TestResult)

    for tdir in args.test:

        if args.recursive_tests:
            tdirList = getTestsInDir(tdir)
        else:
            tdirList = [tdir]

        for t in tdirList:
            for f in args.framework:
                task_queue.put((t, f))

    for i in range(args.jobs):
        t = threading.Thread(
            target=worker_thread, args=(task_queue, args, result_queue)
        )
        t.daemon = True
        t.start()
        threads.append(t)

    # Run until queue is empty
    task_queue.join()

    print_color(LogColors.BOLD_WHITE, "Result summary")

    result_list = []
    while True:
        try:
            test, framework, rc, tol, msg, time_delta = result_queue.get(block=False)
        except queue.Empty:
            break

        result_list.append((test, framework, rc, tol, msg, time_delta))
        results[rc] = results[rc] + 1

    xunit_result = xunit_results()
    xunit_suite = xunit_result.create_suite(args.xunit_classname_prefix)

    # Sort by test name
    for test, framework, rc, tol, err_msg, time_delta in sorted(
        result_list, key=lambda tup: tup[0]
    ):

        test_name = os.path.basename(test)
        class_name = f"{args.xunit_classname_prefix}.{framework}"

        xt = xunit_test(test_name, class_name)

        msg = TestResultErrorStr[rc]

        xt.time = str(
            float(time_delta.seconds) + (float(time_delta.microseconds) * 1e-6)
        )

        if len(msg) > 0:
            print("{} on {} {}".format(msg, framework, test))

        # Add any more verbose messaging for the xml log
        if err_msg:
            msg = "{} {}".format(msg, err_msg)

        if rc == TestResult.PASS:
            pass
        elif rc == TestResult.SKIPPED:
            xt.skipped()
        else:
            xt.failed(msg)

        xunit_suite.tests.append(xt)

        result_queue.task_done()

    xunit_result.write_results(args.xunit_file)

    print("Totals: ", end="")
    for result in TestResult:
        print("{} {}, ".format(results[result], result.name.lower()), end="")
    print()

    if not args.regression_mode and (
        results[TestResult.COMPILER_ERROR] > 0
        or results[TestResult.REF_MODEL_ERROR] > 0
        or results[TestResult.MISMATCH] > 0
    ):
        return 1

    return 0


if __name__ == "__main__":
    exit(main())