blob: 4d6d34575fedcc40e6abec5dff1936a0c0817947 [file] [log] [blame]
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00001"""TOSA result checker script."""
Jeremy Johnsonc8330812024-01-18 16:57:28 +00002# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00003# SPDX-License-Identifier: Apache-2.0
4import argparse
Jeremy Johnsone2b5e872023-09-14 17:02:09 +01005import json
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00006from enum import IntEnum
7from enum import unique
8from pathlib import Path
9
10import numpy as np
Jeremy Johnsone2b5e872023-09-14 17:02:09 +010011from checker.color_print import LogColors
12from checker.color_print import print_color
13from checker.verifier import VerifierError
14from checker.verifier import VerifierLibrary
James Ward24dbc422022-10-19 12:20:31 +010015from generator.tosa_utils import float32_is_valid_bfloat16
Won Jeon2c34b462024-02-06 18:37:00 +000016from generator.tosa_utils import float32_is_valid_float8
Jeremy Johnsone2b5e872023-09-14 17:02:09 +010017from schemavalidation.schemavalidation import TestDescSchemaValidator
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +000018
19
@unique
class TestResult(IntEnum):
    """Outcome codes reported by the result checker.

    The integer values double as the process exit status when the checker
    is run as a script, and as indices into TestResultErrorStr.
    """

    # Success must stay at 0 so that the command line reports success
    PASS = 0
    # One of the result files could not be found
    MISSING_FILE = 1
    # A file or test description could not be parsed / validated
    INCORRECT_FORMAT = 2
    # The implementation result differs from the reference
    MISMATCH = 3
    # The checker itself failed (for example the verifier did not load)
    INTERNAL_ERROR = 4
30
31
# Human readable description for each TestResult value, indexed by that value
TestResultErrorStr = [
    "",  # PASS
    "Missing file",  # MISSING_FILE
    "Incorrect format",  # INCORRECT_FORMAT
    "Mismatch",  # MISMATCH
    "Internal error",  # INTERNAL_ERROR
]
##################################

# Absolute tolerance used for floating point comparisons unless overridden
DEFAULT_FP_TOLERANCE = 1e-3
# Module-wide switch consulted before any result line is printed
result_printing = True
43
44
def set_print_result(enabled):
    """Switch the printing of result lines by this module on or off.

    The value is stored as-is in the module-level ``result_printing`` flag.
    """
    # Rebind the module global rather than creating a function local
    global result_printing
    result_printing = enabled
49
50
def _print_result(color, msg):
    """Print msg in the given color, unless result printing is disabled."""
    # The flag is only read here, so no global declaration is required
    if not result_printing:
        return
    print_color(color, msg)
56
57
def compliance_check(
    imp_result_data,
    ref_result_data,
    bnd_result_data,
    test_name,
    compliance_config,
    ofm_name,
    verify_lib_path,
):
    """Check implementation data for compliance using the verifier library.

    Loads the verifier library from verify_lib_path and asks it to verify
    the implementation data against the reference (and optional bounds)
    data for the output named ofm_name, using compliance_config.

    Returns a tuple of (TestResult, tolerance, message); the tolerance is
    always 0.0 here and the message is empty on a pass.
    """
    vlib = None
    load_error = None
    if verify_lib_path is not None:
        try:
            vlib = VerifierLibrary(verify_lib_path)
        except VerifierError as e:
            load_error = str(e)
    else:
        load_error = "Please supply --verify-lib-path"

    # Guard: without a loaded library nothing can be verified
    if load_error is not None:
        _print_result(LogColors.RED, f"INTERNAL ERROR {test_name}")
        # NOTE(review): "verfier" spelling kept as-is, it is runtime output
        return (
            TestResult.INTERNAL_ERROR,
            0.0,
            f"Could not load verfier library: {load_error}",
        )

    compliant = vlib.verify_data(
        ofm_name, compliance_config, imp_result_data, ref_result_data, bnd_result_data
    )
    if not compliant:
        _print_result(LogColors.RED, f"Results NON-COMPLIANT {test_name}")
        return (
            TestResult.MISMATCH,
            0.0,
            f"Non-compliance results found for {ofm_name}",
        )

    _print_result(LogColors.GREEN, f"Compliance Results PASS {test_name}")
    return (TestResult.PASS, 0.0, "")
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010094
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +000095
def test_check(
    ref_result_path,
    imp_result_path,
    test_name=None,
    quantize_tolerance=0,
    float_tolerance=DEFAULT_FP_TOLERANCE,
    misc_checks=(),
    test_desc=None,
    bnd_result_path=None,
    ofm_name=None,
    verify_lib_path=None,
):
    """Check if the result is the same as the expected reference.

    Args:
        ref_result_path: Path object of the reference result numpy file.
        imp_result_path: Path object of the implementation result numpy file.
        test_name: name used in printed result lines, defaults to "test".
        quantize_tolerance: allowed absolute difference for integer results.
        float_tolerance: absolute tolerance (atol) for float results.
        misc_checks: collection of extra check names; "bf16", "fp8e4m3" and
            "fp8e5m2" are recognised.
        test_desc: optional test description dictionary; it is schema
            validated and, when it contains meta/compliance information,
            checking is delegated to compliance_check.
        bnd_result_path: optional Path object of the bounds result file.
        ofm_name: output name to check in compliance mode; when None the
            single name listed in test_desc["ofm_name"] is used.
        verify_lib_path: path to the verifier library for compliance mode.

    Returns:
        Tuple of (TestResult, tolerance, message). The message is empty on
        a pass and describes the failure otherwise.
    """
    if test_desc:
        # New compliance method - first get test details
        try:
            TestDescSchemaValidator().validate_config(test_desc)
        except Exception as e:
            _print_result(LogColors.RED, f"Test INCORRECT FORMAT {test_name}")
            msg = f"Incorrect test format: {e}"
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    if test_name is None:
        test_name = "test"

    paths = [imp_result_path, ref_result_path, bnd_result_path]
    names = ["Implementation", "Reference", "Bounds"]
    arrays = [None, None, None]

    # Check the files exist and are in the right format
    for idx, (name, path) in enumerate(zip(names, paths)):
        if path is None and name == "Bounds":
            # Bounds can be None - skip it
            continue
        if not path.is_file():
            _print_result(LogColors.RED, f"{name} MISSING FILE {test_name}")
            msg = f"Missing {name} file: {str(path)}"
            return (TestResult.MISSING_FILE, 0.0, msg)
        try:
            arrays[idx] = np.load(path)
        except Exception as e:
            _print_result(LogColors.RED, f"{name} INCORRECT FORMAT {test_name}")
            msg = f"Incorrect numpy format of {str(path)}\nnumpy.load exception: {e}"
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    if test_desc and "meta" in test_desc and "compliance" in test_desc["meta"]:
        # Switch to using the verifier library for full compliance
        if ofm_name is None:
            ofm_name = test_desc["ofm_name"][0]
            if len(test_desc["ofm_name"]) > 1:
                _print_result(LogColors.RED, f"Output Name MISSING FILE {test_name}")
                msg = "Must specify output name (ofm_name) to check as multiple found in desc.json"
                return (TestResult.MISSING_FILE, 0.0, msg)

        compliance_json = test_desc["meta"]["compliance"]

        return compliance_check(
            *arrays,
            test_name,
            compliance_json,
            ofm_name,
            verify_lib_path,
        )

    # Else continue with original checking method
    test_result, reference_result, _ = arrays

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        _print_result(LogColors.RED, f"Results TYPE MISMATCH {test_name}")
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # Size = 1 tensors can be equivalently represented as having rank 0 or rank
    # >= 0, allow that special case
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)
    difference = None

    if np.shape(test_result) != np.shape(reference_result):
        _print_result(LogColors.RED, f"Results MISCOMPARE {test_name}")
        # Reference shape first, so that it matches the "Reference" label
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(reference_result), np.shape(test_result)
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Perform miscellaneous checks
    if "bf16" in misc_checks:
        # Ensure floats are valid bfloat16 values
        test_res_is_bf16 = all(float32_is_valid_bfloat16(f) for f in test_result.flat)
        ref_res_is_bf16 = all(
            float32_is_valid_bfloat16(f) for f in reference_result.flat
        )
        if not (test_res_is_bf16 and ref_res_is_bf16):
            msg = (
                "All output values must be valid bfloat16. "
                f"reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}"
            )
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)
    if "fp8e4m3" in misc_checks or "fp8e5m2" in misc_checks:
        # Ensure floats are valid float8 values
        test_res_is_fp8 = all(float32_is_valid_float8(f) for f in test_result.flat)
        ref_res_is_fp8 = all(float32_is_valid_float8(f) for f in reference_result.flat)
        if not (test_res_is_fp8 and ref_res_is_fp8):
            msg = (
                "All output values must be valid float8. "
                f"reference_result: {ref_res_is_fp8}; test_result: {test_res_is_fp8}"
            )
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # for quantized test, allow +-(quantize_tolerance) error
    if reference_result.dtype in (
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.uint8,
        np.uint16,
    ):
        # Widen before subtracting so that narrow or unsigned types cannot
        # wrap around (for example uint8 0 - 1 would otherwise give 255)
        difference = reference_result.astype(np.int64) - test_result.astype(np.int64)
        abs_difference = np.absolute(difference)

        if np.all(abs_difference <= quantize_tolerance):
            _print_result(LogColors.GREEN, f"Results PASS {test_name}")
            return (TestResult.PASS, 0.0, "")

        # Find the smallest tolerance (up to 10) that would have passed
        tolerance = quantize_tolerance + 1
        while not np.all(abs_difference <= tolerance):
            tolerance = tolerance + 1
            if tolerance > 10:
                break

        if tolerance > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                tolerance
            )
        # Fall-through to below to add failure values

    elif reference_result.dtype == bool:
        # All boolean values must match exactly (types were checked above)
        if np.array_equal(reference_result, test_result):
            _print_result(LogColors.GREEN, f"Results PASS {test_name}")
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        difference = None
        # Fall-through to below to add failure values

    # TODO: update for fp16 tolerance
    elif reference_result.dtype == np.float32 or reference_result.dtype == np.float16:
        tolerance = float_tolerance
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            _print_result(LogColors.GREEN, f"Results PASS {test_name}")
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        difference = reference_result - test_result
        # Fall-through to below to add failure values
    else:
        _print_result(LogColors.RED, f"Results UNSUPPORTED TYPE {test_name}")
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    _print_result(LogColors.RED, f"Results MISCOMPARE {test_name}")
    # Limit how much of large arrays ends up in the failure message
    np.set_printoptions(threshold=128, edgeitems=2)

    if difference is not None:
        tolerance_needed = np.amax(np.absolute(difference))
        msg = "{}\n-- tolerance_needed: {}".format(msg, tolerance_needed)

    msg = "{}\n>> reference_result: {}\n{}".format(
        msg, reference_result.shape, reference_result
    )
    msg = "{}\n<< test_result: {}\n{}".format(msg, test_result.shape, test_result)

    if difference is not None:
        msg = "{}\n!! difference_result: \n{}".format(msg, difference)
    return (TestResult.MISMATCH, tolerance, msg)
287
288
def main(argv=None):
    """Command line entry point comparing a reference and a result file.

    Parses argv (argparse falls back to the process arguments when it is
    None), optionally loads desc.json from the given test directory and
    runs test_check. Returns the resulting TestResult value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "ref_result_path",
        type=Path,
        help="path to the reference model result file to check",
    )
    parser.add_argument(
        "imp_result_path",
        type=Path,
        help="path to the implementation result file to check",
    )
    parser.add_argument(
        "--fp-tolerance", type=float, default=DEFAULT_FP_TOLERANCE, help="FP tolerance"
    )
    parser.add_argument(
        "--test-path", type=Path, help="path to the test that produced the results"
    )
    # Deprecate the incorrectly formatted option by hiding it
    parser.add_argument("--test_path", type=Path, help=argparse.SUPPRESS)
    parser.add_argument(
        "--bnd-result-path",
        type=Path,
        help="path to the reference model bounds result file for the dot product compliance check",
    )
    parser.add_argument(
        "--ofm-name",
        type=str,
        help="name of the output tensor to check, defaults to the first ofm_name listed in the test",
    )
    parser.add_argument(
        "--verify-lib-path",
        type=Path,
        help="path to TOSA verify library",
    )
    args = parser.parse_args(argv)

    # Without a test path there is no description and no test name
    test_desc = None
    test_name = None

    if args.test_path:
        # Get details from the test path
        test_desc_path = args.test_path / "desc.json"
        if not (args.test_path.is_dir() and test_desc_path.is_file()):
            print(f"Invalid test directory {str(args.test_path)}")
            return TestResult.MISSING_FILE

        try:
            with test_desc_path.open("r") as fd:
                test_desc = json.load(fd)
        except Exception as e:
            print(f"Invalid test description file {str(test_desc_path)}: {e}")
            return TestResult.INCORRECT_FORMAT
        test_name = args.test_path.name

    # The tolerance element of the returned tuple is not needed here
    result, _, msg = test_check(
        args.ref_result_path,
        args.imp_result_path,
        float_tolerance=args.fp_tolerance,
        test_name=test_name,
        test_desc=test_desc,
        bnd_result_path=args.bnd_result_path,
        ofm_name=args.ofm_name,
        verify_lib_path=args.verify_lib_path,
    )
    if result != TestResult.PASS:
        print(msg)

    return result
359
360
if __name__ == "__main__":
    # Raise SystemExit directly with the TestResult value as the exit status.
    # The `exit` helper is injected by the site module for interactive use
    # and is not available when the interpreter runs without it (python -S).
    raise SystemExit(main())