blob: 212c8094e35d0857548b9c4a978d76f763575a9f [file] [log] [blame]
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00001"""TOSA result checker script."""
Jeremy Johnsonc8330812024-01-18 16:57:28 +00002# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00003# SPDX-License-Identifier: Apache-2.0
4import argparse
Jeremy Johnsone2b5e872023-09-14 17:02:09 +01005import json
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00006from enum import IntEnum
7from enum import unique
8from pathlib import Path
9
10import numpy as np
Jeremy Johnsone2b5e872023-09-14 17:02:09 +010011from checker.color_print import LogColors
12from checker.color_print import print_color
13from checker.verifier import VerifierError
14from checker.verifier import VerifierLibrary
James Ward24dbc422022-10-19 12:20:31 +010015from generator.tosa_utils import float32_is_valid_bfloat16
Jeremy Johnsone2b5e872023-09-14 17:02:09 +010016from schemavalidation.schemavalidation import TestDescSchemaValidator
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +000017
18
@unique
class TestResult(IntEnum):
    """Outcome codes produced by the result checker.

    The members are integers so that a result can be handed straight back
    to the shell as the process return code.
    """

    # PASS has to stay 0: a zero return code means success on the command line
    PASS = 0
    # One of the result files to compare could not be found
    MISSING_FILE = 1
    # A file (or the test description) could not be parsed
    INCORRECT_FORMAT = 2
    # The results were loaded but do not agree
    MISMATCH = 3
    # The checker itself could not run (e.g. verifier library not loaded)
    INTERNAL_ERROR = 4
29
30
# Human readable text for each TestResult value, indexed by the enum's
# integer value (PASS has no error text).
TestResultErrorStr = [
    "",  # PASS
    "Missing file",  # MISSING_FILE
    "Incorrect format",  # INCORRECT_FORMAT
    "Mismatch",  # MISMATCH
    "Internal error",  # INTERNAL_ERROR
]

# --- Module level settings --------------------------------------------

# Absolute tolerance used for floating point comparisons unless overridden
DEFAULT_FP_TOLERANCE = 0.001
# When False, the coloured one-line result messages are suppressed
result_printing = True
42
43
def set_print_result(enabled):
    """Turn printing of the result messages on or off.

    The value is stored in the module-level ``result_printing`` flag,
    which is consulted each time a result message is about to be printed.
    """
    # Rebind the module global rather than creating a function local
    global result_printing
    result_printing = enabled
48
49
def _print_result(color, msg):
    """Emit ``msg`` in ``color`` unless result printing has been disabled."""
    # Reading a module global needs no ``global`` declaration
    if not result_printing:
        return
    print_color(color, msg)
55
56
def compliance_check(
    imp_result_data,
    ref_result_data,
    bnd_result_data,
    test_name,
    compliance_config,
    ofm_name,
    verify_lib_path,
):
    """Check implementation results for compliance via the verifier library.

    Loads the verifier library from ``verify_lib_path`` and asks it to
    verify the implementation data against the reference (and optional
    bounds) data for the output tensor ``ofm_name``.

    Args:
        imp_result_data: implementation result data to verify.
        ref_result_data: reference result data.
        bnd_result_data: bounds result data, passed through unchanged
            (the caller may supply None).
        test_name: name used in the printed result line.
        compliance_config: compliance configuration handed to the verifier.
        ofm_name: name of the output tensor being verified.
        verify_lib_path: location of the verifier library, or None.

    Returns:
        A ``(TestResult, tolerance, message)`` tuple. The tolerance is
        always 0.0 here; the message is empty on PASS.
    """
    error = None
    vlib = None
    if verify_lib_path is None:
        error = "Please supply --verify-lib-path"
    else:
        try:
            vlib = VerifierLibrary(verify_lib_path)
        except VerifierError as e:
            error = str(e)

    if error is not None:
        # Without a usable library no verdict on the data can be given
        _print_result(LogColors.RED, f"INTERNAL ERROR {test_name}")
        msg = f"Could not load verifier library: {error}"
        return (TestResult.INTERNAL_ERROR, 0.0, msg)

    success = vlib.verify_data(
        ofm_name, compliance_config, imp_result_data, ref_result_data, bnd_result_data
    )
    if success:
        _print_result(LogColors.GREEN, f"Compliance Results PASS {test_name}")
        return (TestResult.PASS, 0.0, "")

    _print_result(LogColors.RED, f"Results NON-COMPLIANT {test_name}")
    return (
        TestResult.MISMATCH,
        0.0,
        f"Non-compliance results found for {ofm_name}",
    )
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010093
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +000094
def test_check(
    ref_result_path,
    imp_result_path,
    test_name=None,
    quantize_tolerance=0,
    float_tolerance=DEFAULT_FP_TOLERANCE,
    misc_checks=(),
    test_desc=None,
    bnd_result_path=None,
    ofm_name=None,
    verify_lib_path=None,
):
    """Check if the result is the same as the expected reference.

    When ``test_desc`` contains ``meta.compliance`` information the check is
    delegated to ``compliance_check``; otherwise the implementation and
    reference arrays are compared directly, using ``quantize_tolerance`` for
    integer data and ``float_tolerance`` for float32/float16 data.

    Args:
        ref_result_path: path object of the reference result numpy file.
        imp_result_path: path object of the implementation result numpy file.
        test_name: name used in printed messages, defaults to "test".
        quantize_tolerance: allowed absolute difference for integer results.
        float_tolerance: absolute tolerance for floating point results.
        misc_checks: collection of extra check names; "bf16" requires all
            values of both results to be valid bfloat16.
        test_desc: test description dictionary (validated when supplied).
        bnd_result_path: optional path object of the bounds result file.
        ofm_name: output name to verify in compliance mode; defaults to the
            only entry of ``test_desc["ofm_name"]``.
        verify_lib_path: verifier library location for compliance mode.

    Returns:
        A ``(TestResult, tolerance, message)`` tuple; the message is empty
        on PASS.
    """
    if test_desc:
        # New compliance method - first get test details
        try:
            TestDescSchemaValidator().validate_config(test_desc)
        except Exception as e:
            _print_result(LogColors.RED, f"Test INCORRECT FORMAT {test_name}")
            msg = f"Incorrect test format: {e}"
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    if test_name is None:
        test_name = "test"

    names = ["Implementation", "Reference", "Bounds"]
    paths = [imp_result_path, ref_result_path, bnd_result_path]
    arrays = []

    # Check the files exist and are in the right format
    for name, path in zip(names, paths):
        if path is None and name == "Bounds":
            # Bounds can be None - keep its slot empty
            arrays.append(None)
            continue
        if not path.is_file():
            _print_result(LogColors.RED, f"{name} MISSING FILE {test_name}")
            msg = f"Missing {name} file: {str(path)}"
            return (TestResult.MISSING_FILE, 0.0, msg)
        try:
            arrays.append(np.load(path))
        except Exception as e:
            _print_result(LogColors.RED, f"{name} INCORRECT FORMAT {test_name}")
            msg = f"Incorrect numpy format of {str(path)}\nnumpy.load exception: {e}"
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    if test_desc and "meta" in test_desc and "compliance" in test_desc["meta"]:
        # Switch to using the verifier library for full compliance
        if ofm_name is None:
            # Only safe to pick a default when there is no ambiguity; check
            # the count before indexing
            if len(test_desc["ofm_name"]) > 1:
                _print_result(LogColors.RED, f"Output Name MISSING FILE {test_name}")
                msg = "Must specify output name (ofm_name) to check as multiple found in desc.json"
                return (TestResult.MISSING_FILE, 0.0, msg)
            ofm_name = test_desc["ofm_name"][0]

        compliance_json = test_desc["meta"]["compliance"]

        return compliance_check(
            *arrays,
            test_name,
            compliance_json,
            ofm_name,
            verify_lib_path,
        )

    # Else continue with original checking method
    test_result, reference_result, _ = arrays

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        _print_result(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # Size = 1 tensors can be equivalently represented as having rank 0 or rank
    # >= 0, allow that special case
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)
    difference = None

    if np.shape(test_result) != np.shape(reference_result):
        _print_result(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        # Reference shape first, to match the wording of the message
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(reference_result), np.shape(test_result)
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Perform miscellaneous checks
    if "bf16" in misc_checks:
        # Ensure floats are valid bfloat16 values
        test_res_is_bf16 = all(float32_is_valid_bfloat16(f) for f in test_result.flat)
        ref_res_is_bf16 = all(
            float32_is_valid_bfloat16(f) for f in reference_result.flat
        )
        if not (test_res_is_bf16 and ref_res_is_bf16):
            msg = (
                "All output values must be valid bfloat16. "
                f"reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}"
            )
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # for quantized test, allow +-(quantize_tolerance) error
    if reference_result.dtype in (
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.uint8,
        np.uint16,
    ):
        # Widen before subtracting: narrow and unsigned integer arithmetic
        # wraps around (e.g. uint8 0 - 1 == 255) and would misreport the error
        difference = reference_result.astype(np.int64) - test_result.astype(np.int64)
        abs_difference = np.absolute(difference)

        if np.all(abs_difference <= quantize_tolerance):
            _print_result(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")

        # Find the smallest tolerance (capped just above 10) that would pass
        tolerance = quantize_tolerance + 1
        while not np.all(abs_difference <= tolerance):
            tolerance = tolerance + 1
            if tolerance > 10:
                break

        if tolerance > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                tolerance
            )
        # Fall-through to below to add failure values

    elif reference_result.dtype == bool:
        # All boolean values must match (dtypes were already checked equal)
        if np.array_equal(reference_result, test_result):
            _print_result(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        difference = None
        # Fall-through to below to add failure values

    # TODO: update for fp16 tolerance
    elif reference_result.dtype == np.float32 or reference_result.dtype == np.float16:
        tolerance = float_tolerance
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            _print_result(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        difference = reference_result - test_result
        # Fall-through to below to add failure values
    else:
        _print_result(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    _print_result(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
    # Keep the dumped arrays short for large tensors
    np.set_printoptions(threshold=128, edgeitems=2)

    if difference is not None:
        tolerance_needed = np.amax(np.absolute(difference))
        msg = "{}\n-- tolerance_needed: {}".format(msg, tolerance_needed)

    msg = "{}\n>> reference_result: {}\n{}".format(
        msg, reference_result.shape, reference_result
    )
    msg = "{}\n<< test_result: {}\n{}".format(msg, test_result.shape, test_result)

    if difference is not None:
        msg = "{}\n!! difference_result: \n{}".format(msg, difference)
    return (TestResult.MISMATCH, tolerance, msg)
274
275
def main(argv=None):
    """Compare a reference result file with an implementation result file.

    Parses the command line in ``argv``, optionally reads ``desc.json`` from
    the supplied test directory, runs ``test_check`` and prints the failure
    message for any result other than PASS. Returns the ``TestResult`` value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "ref_result_path",
        type=Path,
        help="path to the reference model result file to check",
    )
    parser.add_argument(
        "imp_result_path",
        type=Path,
        help="path to the implementation result file to check",
    )
    parser.add_argument(
        "--fp-tolerance", type=float, default=DEFAULT_FP_TOLERANCE, help="FP tolerance"
    )
    parser.add_argument(
        "--test-path", type=Path, help="path to the test that produced the results"
    )
    # Old underscore spelling is still accepted but kept out of the help text
    parser.add_argument("--test_path", type=Path, help=argparse.SUPPRESS)
    parser.add_argument(
        "--bnd-result-path",
        type=Path,
        help="path to the reference model bounds result file for the dot product compliance check",
    )
    parser.add_argument(
        "--ofm-name",
        type=str,
        help="name of the output tensor to check, defaults to the first ofm_name listed in the test",
    )
    parser.add_argument(
        "--verify-lib-path",
        type=Path,
        help="path to TOSA verify library",
    )
    args = parser.parse_args(argv)

    # Without a test directory there is no description and no name
    test_desc = None
    test_name = None

    test_dir = args.test_path
    if test_dir:
        # The test directory must exist and contain a desc.json
        desc_file = test_dir / "desc.json"
        if not (test_dir.is_dir() and desc_file.is_file()):
            print(f"Invalid test directory {test_dir}")
            return TestResult.MISSING_FILE

        try:
            with desc_file.open("r") as fd:
                test_desc = json.load(fd)
        except Exception as e:
            print(f"Invalid test description file {desc_file}: {e}")
            return TestResult.INCORRECT_FORMAT
        test_name = test_dir.name

    result, _, msg = test_check(
        args.ref_result_path,
        args.imp_result_path,
        float_tolerance=args.fp_tolerance,
        test_name=test_name,
        test_desc=test_desc,
        bnd_result_path=args.bnd_result_path,
        ofm_name=args.ofm_name,
        verify_lib_path=args.verify_lib_path,
    )
    if result != TestResult.PASS:
        print(msg)

    return result
346
347
if __name__ == "__main__":
    # The bare ``exit`` helper is installed by the ``site`` module for
    # interactive use and is not guaranteed to exist (e.g. ``python -S``);
    # raising SystemExit gives the same exit status without relying on it.
    raise SystemExit(main())