blob: 694837878c400419330b38734e15112dd7bed566 [file] [log] [blame]
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00001"""TOSA result checker script."""
Jeremy Johnsone2b5e872023-09-14 17:02:09 +01002# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00003# SPDX-License-Identifier: Apache-2.0
4import argparse
Jeremy Johnsone2b5e872023-09-14 17:02:09 +01005import json
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00006from enum import IntEnum
7from enum import unique
8from pathlib import Path
9
10import numpy as np
Jeremy Johnsone2b5e872023-09-14 17:02:09 +010011from checker.color_print import LogColors
12from checker.color_print import print_color
13from checker.verifier import VerifierError
14from checker.verifier import VerifierLibrary
James Ward24dbc422022-10-19 12:20:31 +010015from generator.tosa_utils import float32_is_valid_bfloat16
Jeremy Johnsone2b5e872023-09-14 17:02:09 +010016from schemavalidation.schemavalidation import TestDescSchemaValidator
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +000017
18
@unique
class TestResult(IntEnum):
    """Outcome codes of a result check.

    main() returns one of these and the script passes it to exit(), so the
    value doubles as the process exit status.
    """

    # PASS has to stay 0 so that a successful check exits with status 0
    PASS = 0
    MISSING_FILE = 1
    INCORRECT_FORMAT = 2
    MISMATCH = 3
    INTERNAL_ERROR = 4


# Human readable description of each TestResult, indexed by its integer value
TestResultErrorStr = [
    "",  # TestResult.PASS
    "Missing file",  # TestResult.MISSING_FILE
    "Incorrect format",  # TestResult.INCORRECT_FORMAT
    "Mismatch",  # TestResult.MISMATCH
    "Internal error",  # TestResult.INTERNAL_ERROR
]
37]
38##################################
39
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010040DEFAULT_FP_TOLERANCE = 1e-3
Jeremy Johnsone2b5e872023-09-14 17:02:09 +010041result_printing = True
42
43
def set_print_result(enabled):
    """Turn the console output of _print_result() on or off.

    Args:
        enabled: new value for the module-level ``result_printing`` flag
    """
    global result_printing
    result_printing = enabled
48
49
def _print_result(color, msg):
    """Print a colored result line, unless printing is disabled.

    Args:
        color: color passed through to print_color()
        msg: text to print
    """
    # Reading a module global needs no `global` declaration
    if not result_printing:
        return
    print_color(color, msg)
55
56
def compliance_check(
    imp_result_path,
    ref_result_path,
    bnd_result_path,
    test_name,
    compliance_config,
    ofm_name,
    verify_lib_path,
):
    """Check a result for compliance using the TOSA verifier library.

    Note: test_check() passes the already loaded numpy arrays as the first
    three arguments (despite the ``_path`` suffix); they are forwarded
    unchanged to VerifierLibrary.verify_data().

    Args:
        imp_result_path: implementation result data
        ref_result_path: reference model result data
        bnd_result_path: reference model bounds result data, or None
        test_name: name used in the printed result line
        compliance_config: the "compliance" entry of the test's desc.json "meta"
        ofm_name: name of the output tensor being checked
        verify_lib_path: path to the verifier library; required

    Returns:
        Tuple of (TestResult, tolerance, message); tolerance is always 0.0.
    """
    vlib = None
    error = None
    if verify_lib_path is None:
        error = "Please supply --verify-lib-path"
    else:
        try:
            vlib = VerifierLibrary(verify_lib_path)
        except VerifierError as e:
            error = str(e)

    if error is not None:
        _print_result(LogColors.RED, f"INTERNAL ERROR {test_name}")
        msg = f"Could not load verifier library: {error}"
        return (TestResult.INTERNAL_ERROR, 0.0, msg)

    success = vlib.verify_data(
        ofm_name, compliance_config, imp_result_path, ref_result_path, bnd_result_path
    )
    if success:
        _print_result(LogColors.GREEN, f"Compliance Results PASS {test_name}")
        return (TestResult.PASS, 0.0, "")

    _print_result(LogColors.RED, f"Results NON-COMPLIANT {test_name}")
    return (TestResult.MISMATCH, 0.0, "Non-compliance results found")
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010089
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +000090
def test_check(
    ref_result_path,
    imp_result_path,
    test_name=None,
    quantize_tolerance=0,
    float_tolerance=DEFAULT_FP_TOLERANCE,
    misc_checks=None,
    test_desc=None,
    bnd_result_path=None,
    ofm_name=None,
    verify_lib_path=None,
):
    """Check if the result is the same as the expected reference.

    When ``test_desc`` contains ``meta.compliance`` data the check is
    delegated to compliance_check() (the TOSA verifier library); otherwise an
    element-wise comparison of the two numpy result files is performed.

    Args:
        ref_result_path: path to the reference model result (numpy) file
        imp_result_path: path to the implementation result (numpy) file
        test_name: name used in printed messages, defaults to "test"
        quantize_tolerance: allowed absolute difference for integer results
        float_tolerance: absolute tolerance (``atol``) for float results
        misc_checks: optional list of extra checks; "bf16" is the only one
            handled here
        test_desc: parsed desc.json of the test, or None
        bnd_result_path: optional path to the bounds result file
        ofm_name: output name to check in compliance mode; defaults to the
            single entry of ``test_desc["ofm_name"]``
        verify_lib_path: path to the verifier library (compliance mode only)

    Returns:
        Tuple of (TestResult, tolerance, message); the message is "" on PASS.
    """
    if misc_checks is None:
        misc_checks = []

    if test_desc:
        # New compliance method - first validate the test details
        try:
            TestDescSchemaValidator().validate_config(test_desc)
        except Exception as e:
            _print_result(LogColors.RED, f"Test INCORRECT FORMAT {test_name}")
            msg = f"Incorrect test format: {e}"
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    if test_name is None:
        test_name = "test"

    paths = [imp_result_path, ref_result_path, bnd_result_path]
    names = ["Implementation", "Reference", "Bounds"]
    arrays = [None, None, None]

    # Check the files exist and are in the right format
    for idx, (name, path) in enumerate(zip(names, paths)):
        if path is None:
            if name == "Bounds":
                # Bounds are optional - skip them
                continue
            _print_result(LogColors.RED, f"{name} MISSING FILE {test_name}")
            msg = f"Missing {name} file: {path}"
            return (TestResult.MISSING_FILE, 0.0, msg)
        # Accept plain strings as well as Path objects
        path = Path(path)
        if not path.is_file():
            _print_result(LogColors.RED, f"{name} MISSING FILE {test_name}")
            msg = f"Missing {name} file: {str(path)}"
            return (TestResult.MISSING_FILE, 0.0, msg)
        try:
            arrays[idx] = np.load(path)
        except Exception as e:
            _print_result(LogColors.RED, f"{name} INCORRECT FORMAT {test_name}")
            msg = f"Incorrect numpy format of {str(path)}\nnumpy.load exception: {e}"
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    if test_desc and "meta" in test_desc and "compliance" in test_desc["meta"]:
        # Switch to using the verifier library for full compliance
        if ofm_name is None:
            ofm_name = test_desc["ofm_name"][0]
            if len(test_desc["ofm_name"]) > 1:
                _print_result(LogColors.RED, f"Output Name MISSING FILE {test_name}")
                msg = "Must specify output name (ofm_name) to check as multiple found in desc.json"
                return (TestResult.MISSING_FILE, 0.0, msg)

        compliance_json = test_desc["meta"]["compliance"]

        return compliance_check(
            *arrays,
            test_name,
            compliance_json,
            ofm_name,
            verify_lib_path,
        )

    # Else continue with original checking method
    test_result, reference_result, _ = arrays

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        _print_result(LogColors.RED, f"Results TYPE MISMATCH {test_name}")
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # Size = 1 tensors can be equivalently represented as having rank 0 or rank
    # >= 0, allow that special case
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)
    difference = None

    if np.shape(test_result) != np.shape(reference_result):
        _print_result(LogColors.RED, f"Results MISCOMPARE {test_name}")
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(reference_result), np.shape(test_result)
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Perform miscellaneous checks
    if "bf16" in misc_checks:
        # Ensure floats are valid bfloat16 values
        test_res_is_bf16 = all(float32_is_valid_bfloat16(f) for f in test_result.flat)
        ref_res_is_bf16 = all(
            float32_is_valid_bfloat16(f) for f in reference_result.flat
        )
        if not (test_res_is_bf16 and ref_res_is_bf16):
            _print_result(LogColors.RED, f"Results INCORRECT FORMAT {test_name}")
            msg = (
                "All output values must be valid bfloat16. "
                f"reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}"
            )
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # for quantized test, allow +-(quantize_tolerance) error
    if reference_result.dtype in (
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.uint8,
        np.uint16,
    ):
        # Widen before subtracting so narrow types cannot wrap around
        # (e.g. uint8 3 - 5 would otherwise become 254)
        difference = reference_result.astype(np.int64) - test_result.astype(np.int64)
        abs_difference = np.absolute(difference)

        if np.all(abs_difference <= quantize_tolerance):
            _print_result(LogColors.GREEN, f"Results PASS {test_name}")
            return (TestResult.PASS, 0.0, "")

        # Find the smallest tolerance that would have passed, giving up above 10
        tolerance = quantize_tolerance + 1
        while not np.all(abs_difference <= tolerance):
            tolerance = tolerance + 1
            if tolerance > 10:
                break

        if tolerance > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                tolerance
            )
        # Fall-through to below to add failure values

    elif reference_result.dtype == bool:
        # All boolean values must match (dtypes were already checked equal)
        if np.array_equal(reference_result, test_result):
            _print_result(LogColors.GREEN, f"Results PASS {test_name}")
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        difference = None
        # Fall-through to below to add failure values

    # TODO: update for fp16 tolerance
    elif reference_result.dtype in (np.float32, np.float16):
        tolerance = float_tolerance
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            _print_result(LogColors.GREEN, f"Results PASS {test_name}")
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        difference = reference_result - test_result
        # Fall-through to below to add failure values

    else:
        _print_result(LogColors.RED, f"Results UNSUPPORTED TYPE {test_name}")
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    _print_result(LogColors.RED, f"Results MISCOMPARE {test_name}")

    # Abbreviate large arrays in the message without changing numpy's global
    # print options for the rest of the process
    with np.printoptions(threshold=128, edgeitems=2):
        if difference is not None:
            tolerance_needed = np.amax(np.absolute(difference))
            msg = "{}\n-- tolerance_needed: {}".format(msg, tolerance_needed)

        msg = "{}\n>> reference_result: {}\n{}".format(
            msg, reference_result.shape, reference_result
        )
        msg = "{}\n<< test_result: {}\n{}".format(msg, test_result.shape, test_result)

        if difference is not None:
            msg = "{}\n!! difference_result: \n{}".format(msg, difference)

    return (TestResult.MISMATCH, tolerance, msg)
270
271
def main(argv=None):
    """Check that the supplied reference and result files have the same contents.

    Args:
        argv: command line arguments, defaults to sys.argv[1:] (argparse)

    Returns:
        The TestResult of the check; TestResult.PASS (0) on success.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "ref_result_path",
        type=Path,
        help="path to the reference model result file to check",
    )
    parser.add_argument(
        "imp_result_path",
        type=Path,
        help="path to the implementation result file to check",
    )
    parser.add_argument(
        "--fp-tolerance", type=float, default=DEFAULT_FP_TOLERANCE, help="FP tolerance"
    )
    parser.add_argument(
        "--test-path", type=Path, help="path to the test that produced the results"
    )
    # Deprecate the incorrectly formatted option by hiding it
    parser.add_argument("--test_path", type=Path, help=argparse.SUPPRESS)
    parser.add_argument(
        "--bnd-result-path",
        type=Path,
        help="path to the reference model bounds result file for the dot product compliance check",
    )
    parser.add_argument(
        "--ofm-name",
        type=str,
        help="name of the output tensor to check, defaults to the first ofm_name listed in the test",
    )
    parser.add_argument(
        "--verify-lib-path",
        type=Path,
        help="path to TOSA verify library",
    )
    args = parser.parse_args(argv)

    if args.test_path:
        # Get details from the test path
        test_desc_path = args.test_path / "desc.json"
        if not args.test_path.is_dir() or not test_desc_path.is_file():
            print(f"Invalid test directory {str(args.test_path)}")
            return TestResult.MISSING_FILE

        try:
            # JSON is UTF-8; don't depend on the platform's default encoding
            with test_desc_path.open("r", encoding="utf-8") as fd:
                test_desc = json.load(fd)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError
            print(f"Invalid test description file {str(test_desc_path)}: {e}")
            return TestResult.INCORRECT_FORMAT
        test_name = args.test_path.name
    else:
        test_desc = None
        test_name = None

    result, tolerance, msg = test_check(
        args.ref_result_path,
        args.imp_result_path,
        float_tolerance=args.fp_tolerance,
        test_name=test_name,
        test_desc=test_desc,
        bnd_result_path=args.bnd_result_path,
        ofm_name=args.ofm_name,
        verify_lib_path=args.verify_lib_path,
    )
    if result != TestResult.PASS:
        print(msg)

    return result
342
343
if __name__ == "__main__":
    import sys

    # sys.exit() rather than the site-provided exit(), which is meant for the
    # interactive interpreter and is missing when Python runs with -S.
    # The TestResult returned by main() is an int, so it becomes the exit status.
    sys.exit(main())