blob: 1169a9554e3e7af9f75efca4f5ac1a2d3faa7faf [file] [log] [blame]
"""TOSA result checker script."""
# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
4import argparse
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00005from enum import Enum
6from enum import IntEnum
7from enum import unique
8from pathlib import Path
9
10import numpy as np
James Ward24dbc422022-10-19 12:20:31 +010011from generator.tosa_utils import float32_is_valid_bfloat16
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +000012
13##################################
# Module-wide switch consulted by print_color(); toggled via set_print_in_color().
color_printing = True
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +000015
16
@unique
class LogColors(Enum):
    """ANSI shell escape sequences used to colour status messages."""

    # Reset every attribute back to the terminal default
    NONE = "\u001b[0m"
    # Bold foreground colours
    GREEN = "\u001b[32;1m"
    RED = "\u001b[31;1m"
    YELLOW = "\u001b[33;1m"
    # Bold text without changing the foreground colour
    BOLD_WHITE = "\u001b[1m"
26
27
def set_print_in_color(enabled):
    """Switch coloured status output on or off for the whole module.

    The value is stored in the module-level ``color_printing`` flag.
    """
    global color_printing
    color_printing = enabled
32
33
def print_color(color, msg):
    """Print a status message, wrapped in colour escapes when enabled.

    ``color`` is a LogColors member; its value is emitted before ``msg`` and
    the LogColors.NONE reset sequence after it. With colour printing disabled
    the message is printed unchanged.
    """
    if color_printing:
        print("{}{}{}".format(color.value, msg, LogColors.NONE.value))
        return
    print(msg)
41
42
@unique
class TestResult(IntEnum):
    """Outcome codes produced by a result check.

    The integer values double as the process exit status of this script.
    """

    # Note: PASS must be 0 for command line return success
    PASS = 0
    MISSING_FILE = 1
    INCORRECT_FORMAT = 2
    MISMATCH = 3
    INTERNAL_ERROR = 4
53
54
# Human-readable descriptions, indexed by the matching TestResult value.
TestResultErrorStr = [
    "",  # PASS
    "Missing file",  # MISSING_FILE
    "Incorrect format",  # INCORRECT_FORMAT
    "Mismatch",  # MISMATCH
    "Internal error",  # INTERNAL_ERROR
]
62##################################
63
# Default absolute tolerance used when comparing floating-point results.
DEFAULT_FP_TOLERANCE = 1e-3
65
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +000066
def test_check(
    reference_path,
    result_path,
    test_name="test",
    quantize_tolerance=0,
    float_tolerance=DEFAULT_FP_TOLERANCE,
    misc_checks=None,
):
    """Check if the result is the same as the expected reference.

    Args:
        reference_path: pathlib.Path of the reference file (read by numpy.load).
        result_path: pathlib.Path of the result file (read by numpy.load).
        test_name: name shown in the printed status lines.
        quantize_tolerance: largest absolute per-element difference accepted
            for int32/int64 results.
        float_tolerance: absolute tolerance (``atol``) handed to np.allclose
            for float16/float32 results.
        misc_checks: optional collection of extra check names. "bf16" requires
            every value of both tensors to be a valid bfloat16 value.

    Returns:
        A ``(TestResult, tolerance, message)`` tuple. ``message`` is empty on
        PASS; on failure it describes the problem and, for value mismatches,
        includes both tensors and their difference.
    """
    # None stands in for "no extra checks" so no mutable default is shared.
    if misc_checks is None:
        misc_checks = []

    if not reference_path.is_file():
        print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
        msg = "Missing reference file: {}".format(reference_path)
        return (TestResult.MISSING_FILE, 0.0, msg)
    if not result_path.is_file():
        print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
        msg = "Missing result file: {}".format(result_path)
        return (TestResult.MISSING_FILE, 0.0, msg)

    # Any failure to load is reported as a format problem rather than raised.
    try:
        test_result = np.load(result_path)
    except Exception as e:
        print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
            result_path, e
        )
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)
    try:
        reference_result = np.load(reference_path)
    except Exception as e:
        print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
            reference_path, e
        )
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # Size = 1 tensors can be equivalently represented as having rank 0 or rank
    # >= 0, allow that special case
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)
    difference = None

    if np.shape(test_result) != np.shape(reference_result):
        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        # Reference shape first so it lines up with the "Reference" label.
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(reference_result), np.shape(test_result)
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Perform miscellaneous checks
    if "bf16" in misc_checks:
        # Ensure floats are valid bfloat16 values
        test_res_is_bf16 = all(float32_is_valid_bfloat16(f) for f in test_result.flat)
        ref_res_is_bf16 = all(
            float32_is_valid_bfloat16(f) for f in reference_result.flat
        )
        if not (test_res_is_bf16 and ref_res_is_bf16):
            # Print a status line like every other failure path does.
            print_color(
                LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name)
            )
            msg = (
                "All output values must be valid bfloat16. "
                "reference_result: {}; test_result: {}"
            ).format(ref_res_is_bf16, test_res_is_bf16)
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # for quantized test, allow +-(quantize_tolerance) error
    if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
        abs_difference = np.absolute(reference_result - test_result)

        if np.all(abs_difference <= quantize_tolerance):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")

        # Search upwards for the smallest tolerance that would have passed,
        # giving up once it exceeds 10.
        tolerance = quantize_tolerance + 1
        while not np.all(abs_difference <= tolerance):
            tolerance = tolerance + 1
            if tolerance > 10:
                break

        if tolerance > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                tolerance
            )
        # Fall-through to below to add failure values
        difference = reference_result - test_result

    elif reference_result.dtype == bool:
        # All boolean values must match exactly; the dtype check above already
        # guarantees test_result is boolean too.
        if np.array_equal(reference_result, test_result):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        difference = None
        # Fall-through to below to add failure values

    # TODO: update for fp16 tolerance
    elif reference_result.dtype == np.float32 or reference_result.dtype == np.float16:
        tolerance = float_tolerance
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        difference = reference_result - test_result
        # Fall-through to below to add failure values
    else:
        print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
    # Keep large tensors summarised in the message.
    # NOTE: set_printoptions changes numpy's process-wide print settings.
    np.set_printoptions(threshold=128, edgeitems=2)

    if difference is not None:
        tolerance_needed = np.amax(np.absolute(difference))
        msg = "{}\n-- tolerance_needed: {}".format(msg, tolerance_needed)

    msg = "{}\n>> reference_result: {}\n{}".format(
        msg, reference_result.shape, reference_result
    )
    msg = "{}\n<< test_result: {}\n{}".format(msg, test_result.shape, test_result)

    if difference is not None:
        msg = "{}\n!! difference_result: \n{}".format(msg, difference)
    return (TestResult.MISMATCH, tolerance, msg)
206
207
def main(argv=None):
    """Compare a reference file with a result file given on the command line.

    ``argv`` is the argument list to parse; None makes argparse read
    ``sys.argv``. Returns the TestResult of the comparison and prints the
    failure message when the result is not PASS.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "reference_path", type=Path, help="the path to the reference file to test"
    )
    parser.add_argument(
        "result_path", type=Path, help="the path to the result file to test"
    )
    parser.add_argument(
        "--fp-tolerance", type=float, default=DEFAULT_FP_TOLERANCE, help="FP tolerance"
    )
    options = parser.parse_args(argv)

    outcome = test_check(
        options.reference_path,
        options.result_path,
        float_tolerance=options.fp_tolerance,
    )
    # The tolerance element of the tuple is not needed here.
    result, _, msg = outcome
    if result == TestResult.PASS:
        return result

    print(msg)
    return result
229
230
if __name__ == "__main__":
    # ``exit`` is a helper injected by the ``site`` module and may be absent
    # (for example under ``python -S``). Raising SystemExit needs no import and
    # turns the TestResult returned by main() into the process exit status.
    raise SystemExit(main())