#!/usr/bin/env python3
# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import glob
import json
import math
import os
import queue
import re
import sys
import threading
import traceback
from datetime import datetime
from enum import IntEnum
from enum import unique

import numpy as np
from checker.tosa_result_checker import LogColors
from checker.tosa_result_checker import print_color
from checker.tosa_result_checker import set_print_in_color
from runner.run_command import run_sh_command
from xunit.xunit import xunit_results
from xunit.xunit import xunit_test


def parse_args():
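    """Parse and return the command line arguments."""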
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--test",
        dest="test",
        default=[],
        type=str,
        nargs="+",
        help="Test(s) to run",
    )
    parser.add_argument(
        "-r",
        "--recursive",
        dest="recursive_tests",
        action="store_true",
        help="Recursively search for tests",
    )
    parser.add_argument(
        "--tf-base-dir",
        dest="tf_base_dir",
        type=str,
        required=True,
        help="Tensorflow/MLIR base directory",
    )
    parser.add_argument(
        "--tools-base-dir",
        dest="tools_base_dir",
        type=str,
        required=True,
        help="Reference model base directory",
    )
    parser.add_argument(
        "-p",
        "--precise-mode",
        dest="precise_mode",
        action="store_true",
        help="run in precise mode (FP64)",
    )
    parser.add_argument(
        "-v", "--verbose", dest="verbose", action="count", help="Verbose run"
    )
    parser.add_argument(
        "-dref",
        "--debug-ref-model",
        dest="debug_ref_model",
        action="store_true",
        help="Enable TOSA Reference model debugging",
    )
    parser.add_argument(
        "--tolerance",
        dest="tolerance",
        default=1e-3,
        type=float,
        help="Comparison tolerance value",
    )
    parser.add_argument(
        "--tosa_level",
        dest="tosa_level",
        default="EIGHTK",
        type=str,
        help="A TOSA level defines operator parameter ranges that an implementation shall support. "
        "Config tosa_level for running the reference model only. Default is EIGHTK",
    )
    parser.add_argument(
        "--no-compiler",
        dest="no_compiler",
        action="store_true",
        help="Do not run TF MLIR/tfopt/TOSA compiler. Just run TOSA Reference model",
    )
    parser.add_argument(
        "--no-ref-model",
        dest="no_ref",
        action="store_true",
        help="Do not run TOSA reference model, just run TF MLIR/tfopt/TOSA compiler.",
    )
    parser.add_argument(
        "--valgrind",
        dest="valgrind",
        action="store_true",
        help="Enable valgrind on TOSA Reference Model",
    )
    parser.add_argument(
        "-j", "--jobs", dest="jobs", type=int, default=1, help="Number of parallel jobs"
    )
    parser.add_argument(
        "--no-color",
        "--no-colour",
        dest="no_color",
        action="store_true",
        help="Disable color output",
    )
    parser.add_argument(
        "-f",
        "--framework",
        dest="framework",
        default=[],
        action="append",
        help="Frameworks to test (tf, tflite)",
    )
    parser.add_argument(
        "--override-exclusions",
        dest="override_exclusions",
        default=False,
        action="store_true",
        help="Ignore the framework exclusions listed in the test JSON",
    )
    parser.add_argument(
        "--xunit-file",
        dest="xunit_file",
        type=str,
        default="result.xml",
        help="XUnit result output file",
    )
    parser.add_argument(
        "--xunit-classname-prefix",
        dest="xunit_classname_prefix",
        default="TFUnitTests",
        help="Prefix for xunit classname",
    )
    parser.add_argument(
        "--hex-bool-hack",
        dest="hex_bool_hack",
        default=1,
        type=int,
        help=(
            "Hack around bug in MLIR hex parsing for boolean types"
            " by disabling hex encoding"
        ),
    )
    parser.add_argument(
        "--regression-mode",
        dest="regression_mode",
        default=False,
        action="store_true",
        help="Option to make the script more friendly for Jenkins regressions",
    )
    parser.add_argument(
        "--quantize-tolerance",
        dest="quantize_tolerance",
        default=0,
        type=int,
        help=(
            "Tolerance when comparing TOSA reference model result"
            " to TensorFlow Lite reference"
        ),
    )
    parser.add_argument(
        "--test-dir",
        dest="test_dir",
        default="",
        help="Path to prepend to paths in test.json",
    )

    parser.add_argument(
        "-o", "--output", dest="output_file", help="Redirect script output to a file"
    )

    args = parser.parse_args()

    # No easy way to both do array append and override a default value
    if not args.framework:
        args.framework = ["tf", "tflite"]

    # Autodetect CPU count
    if args.jobs <= 0:
        args.jobs = os.cpu_count()

    return args


@unique
class TestResult(IntEnum):
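    """Result codes for a single test run."""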
    PASS = 0
    COMPILER_ERROR = 1
    REF_MODEL_ERROR = 2
    REF_MODEL_UNPREDICTABLE = 3
    REF_MODEL_RUNTIME_ERROR = 4
    MISMATCH = 5
    NOT_LOWERED = 6
    INVALID_MLIR = 7
    INTERNAL_ERROR = 8
    SKIPPED = 9


TestResultErrorStr = [
    "",
    "Compiler error",
    "Reference model error",
    "Reference model unpredictable",
    "Reference model runtime error",
    "Mismatch",
    "Not lowered",
    "Invalid MLIR",
    "Internal error",
    "",
]


def parse_compiler_output(compiler_stdout, compiler_stderr):
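    """Return NOT_LOWERED if the compiler reported unlowered ops, otherwise PASS."""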
    # Look for "has not been lowered yet, skipped" strings in stdout
    expr = re.compile(".* has not been lowered yet, skipped.*")

    for line in compiler_stdout.splitlines():
        if expr.match(line):
            return TestResult.NOT_LOWERED

    return TestResult.PASS


def parse_reference_model_output(ref_model_stdout, ref_model_stderr):
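    """Map reference model stderr output to a TestResult code."""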
    # Look for UNPREDICTABLE or error graph status strings in stderr
    unpredictable_expr = re.compile(r".*UNPREDICTABLE.*")
    error_expr = re.compile(".* Graph result: ERROR.*")
    unknown_expr = re.compile(".* Unknown graph status code.*")

    for line in ref_model_stderr.splitlines():
        if unpredictable_expr.match(line):
            return TestResult.REF_MODEL_UNPREDICTABLE
        elif error_expr.match(line):
            return TestResult.REF_MODEL_ERROR
        elif unknown_expr.match(line):
            return TestResult.REF_MODEL_RUNTIME_ERROR

    return TestResult.PASS


# write a self-contained test descriptor in json format
def write_reference_runner_json(
    filename,
    tosa_filename,
    ifm_name,
    ifm_file,
    ofm_name,
    ofm_file,
    expected_failure=False,
):
    """Write a json test file so that it is fairly easy to pick up the test
    and generate commands for third party tool"""
    test_desc = dict()

    test_desc["tosa_file"] = tosa_filename
    test_desc["ifm_name"] = ifm_name
    test_desc["ifm_file"] = ifm_file
    test_desc["ofm_name"] = ofm_name
    test_desc["ofm_file"] = ofm_file
    test_desc["expected_failure"] = expected_failure

    with open(filename, "w") as f:
        json.dump(test_desc, f, indent=" ")


def run_test(args, test, framework):

    # parse test_name from test directory path
    test_path = test.split("/")
    test_name = None
    for t in test_path[::-1]:
        if len(t) != 0:
            test_name = t
            break
    if not test_name:
        raise Exception("Could not parse test_name from {}".format(test))

    print_color(LogColors.GREEN, "## Running {} test {}".format(framework, test_name))

    msg = ""

    try:
        with open(os.path.join(test, "test.json"), "r") as f:
            test_desc = json.load(f)
    except Exception:
        raise Exception(
            "Could not load or parse test from {}".format(
                os.path.join(test, "test.json")
            )
        )

    try:
        if not args.override_exclusions:
            for excl in test_desc["framework_exclusions"]:
                if excl == framework:
                    print_color(LogColors.GREEN, "Results SKIPPED")
                    return (TestResult.SKIPPED, 0.0, "")
    except KeyError:
        pass

    tf_tools_dir = os.path.abspath(
        "{}/bazel-bin/tensorflow/compiler/mlir".format(args.tf_base_dir)
    )

    pre_opt_filename = os.path.join(test, "test_{}.preopt.mlir".format(framework))
    post_opt_filename = os.path.join(test, "test_{}.postopt.mlir".format(framework))
    if args.test_dir:
        test_path_prepend = args.test_dir
    else:
        test_path_prepend = test

    # 1. Framework to MLIR translator command
    if framework == "tf":
        if test_desc["tf_model_filename"].endswith(".mlir"):
            pre_opt_filename = test_desc["tf_model_filename"]
            translate_mlir_cmd = []
        else:
            translate_mlir_cmd = [
                os.path.join(tf_tools_dir, "tf-mlir-translate"),
                "--graphdef-to-mlir",
                "--tf-enable-shape-inference-on-import",
                "--tf-output-arrays={}".format(test_desc["tf_result_name"]),
                os.path.join(test_path_prepend, test_desc["tf_model_filename"]),
                "-o",
                pre_opt_filename,
            ]
    elif framework == "tflite":
        if test_desc["tflite_model_filename"].endswith(".mlir"):
            pre_opt_filename = test_desc["tflite_model_filename"]
            translate_mlir_cmd = []
        else:
            translate_mlir_cmd = [
                os.path.join(tf_tools_dir, "lite", "flatbuffer_translate"),
                "--tflite-flatbuffer-to-mlir",
                os.path.join(test_path_prepend, test_desc["tflite_model_filename"]),
                "--output-arrays={}".format(test_desc["tflite_result_name"]),
                "-o",
                pre_opt_filename,
            ]
    else:
        raise Exception("Unknown framework: {}".format(framework))

    # Any additional inputs to the translator?
    input_tensor_prefix = "TosaInput_"
    flatbuffer_dir = "flatbuffer-{}".format(framework)
    mlir_opts = []

    # Temporary hack: MLIR's new hex encoding of large tensors does not work for
    # boolean types
    # for TF hash 8e8041d594a888eb67eafa5cc62627d7e9ca8082
    if test.endswith("_bool") and args.hex_bool_hack:
        mlir_opts.append("--mlir-print-elementsattrs-with-hex-if-larger=-1")

    try:
        # specify input tensors if test is generated from .pb
        if framework == "tf":
            # Convert the shape to a mlir-friendly string
            shapes = []
            for curr_shape in test_desc["ifm_shape"]:
                shape_str = ""
                for dim in curr_shape:
                    shape_str = shape_str + str(dim) + ","
                shapes.append(shape_str)

            translate_mlir_cmd.extend(
                ["--tf-input-arrays", ",".join(test_desc["ifm_name"])]
            )
            translate_mlir_cmd.extend(["--tf-input-shapes", ":".join(shapes)])

        # Write the hard-coded placeholder input (reshaped as necessary) to
        # the file that the compiler specified.
        reference_runner_ifm_name = []
        for i in range(len(test_desc["ifm_file"])):

            ifm_tensor_name = "{}{}".format(input_tensor_prefix, i)

            assert test_desc["ifm_file"][i].endswith(".npy")
            ifm_np = np.load(os.path.join(test, test_desc["ifm_file"][i]))

            # We sometimes encounter input shape/expected input shape mismatches
            # due to a missing batch dimension on the input (e.g. a single 3D image).
            #
            # Make sure input numpy and input shape from descriptor match,
            # expand_dims on the outer dimensions until the rank matches,
            # then do the shape comparison.
            while len(list(ifm_np.shape)) < len(test_desc["ifm_shape"][i]):
                ifm_np = np.expand_dims(ifm_np, axis=0)

            # After legalization, complex tensors are expected to be represented
            # as a single floating point tensor of shape [?, ..., ?, 2].
            expected_shape = test_desc["ifm_shape"][i]
            if test.endswith("c64"):
                expected_shape.append(2)

            assert list(ifm_np.shape) == expected_shape

            reference_runner_ifm_name.append(ifm_tensor_name)

    except KeyError:
        # No additional inputs. Ignore.
        pass

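    # tf-opt command: lower the TF executor dialect to functional control flow
    # before TOSA legalization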
    tf_opt_cmd = [
        os.path.join(tf_tools_dir, "tf-opt"),
        "--tf-executor-to-functional-conversion",
        "--verify-each",
        pre_opt_filename,
        "-o",
        post_opt_filename,
    ]

    translate_mlir_cmd.extend(mlir_opts)
    tf_opt_cmd.extend(mlir_opts)

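    # Legalization command: tf-opt with the TF/TFLite-to-TOSA pipeline, which also
    # serializes the lowered graph to a TOSA flatbuffer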
    compiler_cmd = [os.path.join(tf_tools_dir, "tf-opt")]

    if framework == "tf":
        compiler_cmd.append("--tf-to-tosa-pipeline")
    elif framework == "tflite":
        compiler_cmd.append("--tfl-to-tosa-pipeline")
        compiler_cmd.append("--tosa-strip-quant-types")

    tosa_mlir_filename = os.path.join(test, "output_{}.tosa.mlir".format(framework))

    flatbuffer_dir_fullpath = os.path.join(test, flatbuffer_dir)

    os.makedirs(flatbuffer_dir_fullpath, exist_ok=True)

    compiler_cmd.extend(
        [
            "--verify-each",
            post_opt_filename,
            "-o",
            tosa_mlir_filename,
            "--tosa-serialize",
            "--tosa-flatbuffer-filename={}".format(
                os.path.join(flatbuffer_dir_fullpath, "{}.tosa".format(test_name))
            ),
        ]
    )

    if not args.no_compiler:
        try:
            if translate_mlir_cmd:
                run_sh_command(translate_mlir_cmd, args.verbose, True)
            if tf_opt_cmd:
                run_sh_command(tf_opt_cmd, args.verbose, True)
        except Exception as e:
            print_color(
                LogColors.RED, "Results INVALID_MLIR {}: {}".format(test_name, e)
            )
            return (TestResult.INVALID_MLIR, 0.0, e)

        try:

            compiler_stdout, compiler_stderr = run_sh_command(
                compiler_cmd, args.verbose, True
            )
            compiler_rc = parse_compiler_output(compiler_stdout, compiler_stderr)
            if compiler_rc == TestResult.NOT_LOWERED:
                print_color(
                    LogColors.RED,
                    "Results NOT_LOWERED {}, framework {}".format(test_name, framework),
                )
                return (TestResult.NOT_LOWERED, 0.0, "")
        except Exception as e:
            if "same scale constraint" in str(e):
                print_color(
                    LogColors.RED, "Results INVALID_MLIR {}: {}".format(test_name, e)
                )
                return (TestResult.INVALID_MLIR, 0.0, e)
            else:
                print_color(
                    LogColors.RED, "Results COMPILER_ERROR {}: {}".format(test_name, e)
                )
                return (TestResult.COMPILER_ERROR, 0.0, e)

    if framework == "tf":
        try:
            tf_result = np.load(os.path.join(test, test_desc["tf_result_npy_filename"]))
        except KeyError:
            assert 0, "fail to load tf result numpy"
    elif framework == "tflite":
        try:
            tf_result = np.load(
                os.path.join(test, test_desc["tflite_result_npy_filename"])
            )
        except KeyError:
            assert 0, "fail to load tflite result numpy"

    # TOSA has no notion of complex datatypes, it represents complex values using two
    # fp32 output tensors representing real and imaginary values. When legalizing
    # complex operations from frameworks, these two output tensors are combined into
    # a single tensor of shape [?, ..., ?, 2] whereby each inner pair of values
    # represents the real and imaginary parts of a complex value. This is completed
    # by inserting reshape and concatenate TOSA operations during the legalization to
    # maintain a one-to-one correspondence with framework outputs, thus simplifying
    # legalization. Here tf_result should also match this format before being
    # compared to the ref model output.
    if tf_result.dtype == np.complex64:
        ifm_shape = tf_result.shape + (2,)
        tf_result = tf_result.view(np.float32)
        tf_result = tf_result.reshape(ifm_shape)

    # Generate test descriptor per flatbuffer generation
    # Input .npy will be shared across different frameworks
    # Output .npy will be generated in its corresponding flatbuffer
    reference_runner_ifm_file = [
        os.path.join("..", ifm_file) for ifm_file in test_desc["ifm_file"]
    ]

    # Check if there's any operator in output graph.
    empty_graph = True
    with open(tosa_mlir_filename, "r") as f:
        for line in f:
            if re.search('"tosa.*"', line):
                empty_graph = False

                break

    # Fast-forward input tensor to output tensor if TOSA graph is empty.
    if empty_graph:
        reference_runner_ofm_name = reference_runner_ifm_name
    else:
        reference_runner_ofm_name = ["TosaOutput_0"]

    write_reference_runner_json(
        filename=os.path.join(test, flatbuffer_dir, "desc.json"),
        tosa_filename="{}.tosa".format(test_name),
        ifm_name=reference_runner_ifm_name,
        ifm_file=reference_runner_ifm_file,
        ofm_name=reference_runner_ofm_name,
        ofm_file=["ref_model_output_0.npy"],
    )

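    # Build the TOSA reference model command using the descriptor written above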
    ref_model_cmd = [
        os.path.join(
            args.tools_base_dir, "build", "reference_model", "tosa_reference_model"
        ),
        "--test_desc={}".format(os.path.join(test, flatbuffer_dir, "desc.json")),
    ]

    if args.debug_ref_model:
        ref_model_cmd.extend(["-D ALL", "-l high"])

    if args.precise_mode:
        ref_model_cmd.extend(["--precise_mode=1"])

    if args.valgrind:
        ref_model_cmd = [
            "valgrind",
            "--show-leak-kinds=all",
            "--log-fd=1",
            "-q",
        ] + ref_model_cmd

    ref_model_cmd = ref_model_cmd + ["--tosa_level={}".format(args.tosa_level)]

    # Clean out any ref_model result first
    # (os.remove does not expand wildcards, so glob for the stale files explicitly)
    for stale_result in glob.glob(os.path.join(test, flatbuffer_dir, "ref_model_*.npy")):
        try:
            os.remove(stale_result)
        except FileNotFoundError:
            pass

    if args.no_ref:
        return (TestResult.PASS, 0.0, msg)

    try:
        ref_model_stdout, ref_model_stderr = run_sh_command(
            ref_model_cmd, args.verbose, True
        )
        ref_model_rc = parse_reference_model_output(ref_model_stdout, ref_model_stderr)
        if ref_model_rc != TestResult.PASS:
            return (ref_model_rc, 0.0, "")
    except Exception as e:
        ref_model_rc = parse_reference_model_output("", str(e))
        if ref_model_rc != TestResult.PASS:
            print_color(
                LogColors.RED,
                "Results {} {}: {}".format(
                    TestResultErrorStr[ref_model_rc], test_name, e
                ),
            )
            return (ref_model_rc, 0.0, "")
        print_color(
            LogColors.RED,
            "Results REF_MODEL_RUNTIME_ERROR {}: {}".format(test_name, e),
        )
        return (TestResult.REF_MODEL_RUNTIME_ERROR, 0.0, e)

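    # Cast the framework result to the dtype the reference model is expected to
    # produce so that the dtype assertion and comparison below line up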
    if args.precise_mode and (
        tf_result.dtype == np.float16 or tf_result.dtype == np.float32
    ):
        tf_result = tf_result.astype(np.float64)
    elif tf_result.dtype == np.float16:
        tf_result = tf_result.astype(np.float32)
    elif (
        tf_result.dtype == np.uint8
        or tf_result.dtype == np.int8
        or tf_result.dtype == np.int16
        or tf_result.dtype == np.int64
    ):
        tf_result = tf_result.astype(np.int32)

    # For now, search for the output from ref_model
    ref_model_result_files = glob.glob(
        os.path.join(test, flatbuffer_dir, "ref_model_*.npy")
    )
    ref_model_result = np.load(ref_model_result_files[0])

    assert (
        tf_result.dtype == ref_model_result.dtype
    ), "Numpy type mismatch {} != {} when comparing result".format(
        tf_result.dtype, ref_model_result.dtype
    )

    # Size comparison
    # Size = 1 tensors can be equivalently represented as having rank 0 or rank
    # >= 0, allow that special case
    tf_result = np.squeeze(tf_result)
    ref_model_result = np.squeeze(ref_model_result)

    if np.shape(tf_result) != np.shape(ref_model_result):
        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(tf_result), np.shape(ref_model_result)
        )
        print(msg)
        return (TestResult.MISMATCH, 0.0, msg)

    # for quantized test, allow +-(args.quantize_tolerance) error
    if ref_model_result.dtype == np.int32:
        assert tf_result.dtype == np.int32

        if np.all(np.absolute(ref_model_result - tf_result) <= args.quantize_tolerance):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
        else:
            print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))

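            # Find the smallest integer tolerance (capped at 10) that would have
            # allowed the comparison to pass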
            tolerance = args.quantize_tolerance + 1
            while not np.all(
                np.absolute(ref_model_result - tf_result) <= tolerance
            ):
                tolerance = tolerance + 1
                if tolerance >= 10:
                    break

            msg = "Result is within {} {}".format(tolerance, test)
            print(msg)

            np.set_printoptions(threshold=128)
            print("tf_result: {}\n".format(tf_result.shape))
            print(tf_result)
            print("ref_model_result: {}\n".format(ref_model_result.shape))
            print(ref_model_result)
            # print(tf_result - ref_model_result)
            return (TestResult.MISMATCH, tolerance, msg)
    else:
        if np.allclose(
            ref_model_result, tf_result, atol=args.tolerance, equal_nan=True
        ):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
        else:
            print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))

            # Many of these tests would match with a reasonably looser tolerance.
            # Determine what would have worked.
            tolerance = args.tolerance * 10.0
            while not np.allclose(
                ref_model_result, tf_result, atol=tolerance, equal_nan=True
            ):
                tolerance = tolerance * 10.0
                if tolerance > 1.0e10:
                    tolerance = math.inf
                    break

            msg = "Result is within {:.0e} {}".format(tolerance, test_name)
            print(msg)

            np.set_printoptions(precision=4, threshold=128)
            print("tf_result: {}\n".format(tf_result.shape))
            print(tf_result)
            print("ref_model_result: {}\n".format(ref_model_result.shape))
            print(ref_model_result)
            # print(tf_result - ref_model_result)
            return (TestResult.MISMATCH, tolerance, msg)

    return (TestResult.PASS, args.tolerance, msg)


def worker_thread(task_queue, args, result_queue):
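    """Pull (test, framework) tuples off the task queue, run them and push the
    results onto the result queue."""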
    while True:
        try:
            (test, framework) = task_queue.get(block=False)
        except queue.Empty:
            break

        if test is None:
            break

        msg = ""
        start_time = datetime.now()
        try:
            (rc, tolerance, msg) = run_test(args, test, framework)
        except Exception as e:
            print("Internal regression error: {}".format(e))
            print(
                "".join(
                    traceback.format_exception(
                        type(e), e, e.__traceback__
                    )
                )
            )
            rc = TestResult.INTERNAL_ERROR
            tolerance = 0.0

        end_time = datetime.now()

        result_queue.put((test, framework, rc, tolerance, msg, end_time - start_time))
        task_queue.task_done()

    return True


def getTestsInDir(directory):
    # Recursively find any tests in this directory
    if os.path.isfile(os.path.join(directory, "test.json")):
        return [directory]
    elif os.path.isdir(directory):
        test_list = []
        for d in glob.glob(os.path.join(directory, "*")):
            test_list.extend(getTestsInDir(d))
        return test_list
    else:
        return []


def main():
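    """Entry point: queue up tests, run them across worker threads and report results."""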
    args = parse_args()

    set_print_in_color(not args.no_color)

    if args.output_file:
        set_print_in_color(False)
        sys.stdout = open(args.output_file, "w")

    # Disable TF info messages
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    task_queue = queue.Queue()
    result_queue = queue.Queue()

    threads = []

    # Result counters for each of the TestResult return codes
    results = [0] * len(TestResult)

    for tdir in args.test:

        if args.recursive_tests:
            tdirList = getTestsInDir(tdir)
        else:
            tdirList = [tdir]

        for t in tdirList:
            for f in args.framework:
                task_queue.put((t, f))

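    # Spawn worker threads to process the task queue in parallel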
    for i in range(args.jobs):
        t = threading.Thread(
            target=worker_thread, args=(task_queue, args, result_queue)
        )
        t.daemon = True
        t.start()
        threads.append(t)

    # Run until queue is empty
    task_queue.join()

    print_color(LogColors.BOLD_WHITE, "Result summary")

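    # Drain the result queue and tally results per return code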
    result_list = []
    while True:
        try:
            test, framework, rc, tol, msg, time_delta = result_queue.get(block=False)
        except queue.Empty:
            break

        result_list.append((test, framework, rc, tol, msg, time_delta))
        results[rc] = results[rc] + 1

    xunit_result = xunit_results()
    xunit_suite = xunit_result.create_suite(args.xunit_classname_prefix)

    # Sort by test name
    for test, framework, rc, tol, err_msg, time_delta in sorted(
        result_list, key=lambda tup: tup[0]
    ):

        test_name = os.path.basename(test)
        class_name = f"{args.xunit_classname_prefix}.{framework}"

        xt = xunit_test(test_name, class_name)

        msg = TestResultErrorStr[rc]

        xt.time = str(
            float(time_delta.seconds) + (float(time_delta.microseconds) * 1e-6)
        )

        if len(msg) > 0:
            print("{} on {} {}".format(msg, framework, test))

        # Add any more verbose messaging for the xml log
        if err_msg:
            msg = "{} {}".format(msg, err_msg)

        if rc == TestResult.PASS:
            pass
        elif rc == TestResult.SKIPPED:
            xt.skipped()
        else:
            xt.failed(msg)

        xunit_suite.tests.append(xt)

        result_queue.task_done()

    xunit_result.write_results(args.xunit_file)

    print("Totals: ", end="")
    for result in TestResult:
        print("{} {}, ".format(results[result], result.name.lower()), end="")
    print()

    if not args.regression_mode and (
        results[TestResult.COMPILER_ERROR] > 0
        or results[TestResult.REF_MODEL_ERROR] > 0
        or results[TestResult.MISMATCH] > 0
    ):
        return 1

    return 0


if __name__ == "__main__":
    exit(main())