blob: b7a76b6b6a58004c88d4bbc22113fd4da15550b0 [file] [log] [blame]
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00001"""TOSA result checker script."""
2# Copyright (c) 2020-2022, ARM Limited.
3# SPDX-License-Identifier: Apache-2.0
4import argparse
5import os
6from enum import Enum
7from enum import IntEnum
8from enum import unique
9from pathlib import Path
10
11import numpy as np
James Ward24dbc422022-10-19 12:20:31 +010012from generator.tosa_utils import float32_is_valid_bfloat16
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +000013
14##################################
# Module-wide switch consulted by print_color(); toggle it through
# set_print_in_color() rather than assigning to it from outside.
color_printing = True
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +000016
17
@unique
class LogColors(Enum):
    """ANSI terminal escape sequences used to colorize status lines.

    Each member's value is the raw escape prefix; NONE resets the terminal
    back to its default attributes and is appended after colored text.
    """

    NONE = "\u001b[0m"  # reset all attributes
    GREEN = "\u001b[32;1m"  # bold green
    RED = "\u001b[31;1m"  # bold red
    YELLOW = "\u001b[33;1m"  # bold yellow
    BOLD_WHITE = "\u001b[1m"  # bold, default color
27
28
def set_print_in_color(enabled):
    """Switch colored status output on or off for the whole module.

    Args:
        enabled: Truthy to wrap messages in escape sequences, falsy to print
            plain text. Stored as-is in the module-level flag.
    """
    global color_printing  # rebind the module-level flag, not a local
    color_printing = enabled
33
34
def print_color(color, msg):
    """Print a status message, colorized only when color output is enabled.

    Args:
        color: A LogColors member whose value is prefixed to the message.
        msg: The message to print.
    """
    if color_printing:
        # Wrap in the color prefix and always reset afterwards.
        print(f"{color.value}{msg}{LogColors.NONE.value}")
    else:
        print(msg)
42
43
@unique
class TestResult(IntEnum):
    """Outcome codes returned by test_check.

    Members are numbered consecutively from 0 in declaration order.
    """

    # Note: PASS must be 0 for command line return success
    PASS, MISSING_FILE, INCORRECT_FORMAT, MISMATCH, INTERNAL_ERROR = range(5)
54
55
# One human-readable description per TestResult member, in value order.
TestResultErrorStr = [
    "",  # PASS: no error text
    "Missing file",  # MISSING_FILE
    "Incorrect format",  # INCORRECT_FORMAT
    "Mismatch",  # MISMATCH
    "Internal error",  # INTERNAL_ERROR
]
63##################################
64
65
def _integer_abs_difference(reference, result):
    """Return the element-wise absolute difference of two integer arrays.

    A plain ``reference - result`` silently wraps around for int64 inputs
    (e.g. INT64_MAX - INT64_MIN evaluates to -1), which can make a huge error
    look tiny. Ordering each pair first and subtracting in uint64 yields the
    exact distance, which always fits in 64 unsigned bits.
    """
    high = np.asarray(np.maximum(reference, result)).astype(np.uint64)
    low = np.asarray(np.minimum(reference, result)).astype(np.uint64)
    # The modular uint64 subtraction is intentional; silence the overflow
    # warning numpy may raise when the inputs are 0-d.
    with np.errstate(over="ignore"):
        return high - low


def test_check(
    reference,
    result,
    test_name="test",
    quantize_tolerance=0,
    float_tolerance=1e-3,
    misc_checks=None,
):
    """Check if the result is the same as the expected reference.

    Both files are loaded with ``numpy.load`` and compared by dtype, then by
    shape (after squeezing out size-1 dimensions), then by value:

    * int32/int64: every element must be within ``quantize_tolerance``.
    * bool: every element must match exactly.
    * float32/float16: ``numpy.allclose`` with ``atol=float_tolerance``
      (NaNs compare equal).

    Any other dtype is reported as a mismatch.

    Args:
        reference: Path to the expected (reference) numpy file.
        result: Path to the numpy file produced by the test.
        test_name: Name used in the printed status lines.
        quantize_tolerance: Maximum absolute difference allowed for integer
            results.
        float_tolerance: Absolute tolerance used for floating-point results.
        misc_checks: Optional collection of extra check names. ``"bf16"``
            requires every value in both tensors to be a valid bfloat16.
            ``None`` (the default) means no extra checks.

    Returns:
        A ``(TestResult, tolerance, message)`` tuple. ``message`` is empty on
        success. ``tolerance`` is 0.0 for early failures, the float tolerance
        for float results, and for failing integer results the smallest
        tolerance (searched up to just past 10) that would have passed.
    """
    if misc_checks is None:
        misc_checks = ()

    if not os.path.isfile(reference):
        print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
        msg = "Missing reference file: {}".format(reference)
        return (TestResult.MISSING_FILE, 0.0, msg)
    if not os.path.isfile(result):
        print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
        msg = "Missing result file: {}".format(result)
        return (TestResult.MISSING_FILE, 0.0, msg)

    try:
        test_result = np.load(result)
    except Exception as e:
        print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(result, e)
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)
    try:
        reference_result = np.load(reference)
    except Exception as e:
        print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
            reference, e
        )
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # Size = 1 tensors can be equivalently represented as having rank 0 or
    # rank >= 1, so allow that special case. Note that np.squeeze drops every
    # size-1 dimension, so e.g. shapes (1, 3) and (3,) also compare equal.
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)

    if np.shape(test_result) != np.shape(reference_result):
        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(reference_result), np.shape(test_result)
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Perform miscellaneous checks
    if "bf16" in misc_checks:
        # Ensure floats are valid bfloat16 values
        test_res_is_bf16 = all(float32_is_valid_bfloat16(f) for f in test_result.flat)
        ref_res_is_bf16 = all(
            float32_is_valid_bfloat16(f) for f in reference_result.flat
        )
        if not (test_res_is_bf16 and ref_res_is_bf16):
            print_color(
                LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name)
            )
            msg = (
                "All output values must be valid bfloat16. "
                "reference_result: {}; test_result: {}".format(
                    ref_res_is_bf16, test_res_is_bf16
                )
            )
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # for quantized test, allow +-(quantize_tolerance) error
    if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
        abs_diff = _integer_abs_difference(reference_result, test_result)
        # Python int, so the comparisons below are exact.
        max_diff = int(np.max(abs_diff, initial=0))

        if max_diff <= quantize_tolerance:
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")

        # Find the smallest tolerance above quantize_tolerance that would
        # have passed, giving up once it exceeds 10.
        tolerance = quantize_tolerance + 1
        while max_diff > tolerance:
            tolerance = tolerance + 1
            if tolerance > 10:
                break

        if tolerance > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                tolerance
            )
        # Fall-through to below to add failure values

    elif reference_result.dtype == bool:
        # All boolean values must match exactly (dtypes were already checked
        # to be identical above).
        if np.array_equal(reference_result, test_result):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        # Fall-through to below to add failure values

    # TODO: update for fp16 tolerance
    elif reference_result.dtype == np.float32 or reference_result.dtype == np.float16:
        tolerance = float_tolerance
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        # Fall-through to below to add failure values
    else:
        print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
    # Summarize large tensors in the message without permanently changing the
    # process-wide numpy print options.
    with np.printoptions(threshold=128):
        msg = "{}\ntest_result: {}\n{}".format(msg, test_result.shape, test_result)
        msg = "{}\nreference_result: {}\n{}".format(
            msg, reference_result.shape, reference_result
        )
    return (TestResult.MISMATCH, tolerance, msg)
191
192
def main(argv=None):
    """Command-line entry point: compare one result file against its reference.

    Args:
        argv: Optional list of argument strings. ``None`` lets argparse read
            the process command line.

    Returns:
        The TestResult produced by test_check. The failure message is printed
        when the result is not PASS.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "reference_path", type=Path, help="the path to the reference file to test"
    )
    arg_parser.add_argument(
        "result_path", type=Path, help="the path to the result file to test"
    )
    options = arg_parser.parse_args(argv)

    outcome, _tolerance, message = test_check(
        options.reference_path, options.result_path
    )
    if outcome != TestResult.PASS:
        print(message)
    return outcome
211
212
if __name__ == "__main__":
    # Use SystemExit directly: the exit() helper is injected by the site
    # module for interactive use and is absent under "python -S".
    raise SystemExit(main())