"""TOSA result checker script."""
# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
from enum import IntEnum
from enum import unique
from pathlib import Path

import numpy as np
from checker.color_print import LogColors
from checker.color_print import print_color
from checker.verifier import VerifierError
from checker.verifier import VerifierLibrary
from generator.tosa_utils import float32_is_valid_bfloat16
from schemavalidation.schemavalidation import TestDescSchemaValidator


@unique
class TestResult(IntEnum):
    """Outcome codes produced by the result checker.

    The members are plain integers so that a result can be handed straight
    back to the shell as the process exit status.
    """

    # Note: PASS must be 0 for command line return success
    PASS = 0
    MISSING_FILE = 1
    INCORRECT_FORMAT = 2
    MISMATCH = 3
    INTERNAL_ERROR = 4


# Short human-readable descriptions, ordered to line up with the TestResult
# values (index 0 is PASS and therefore empty).
TestResultErrorStr = [
    "",
    "Missing file",
    "Incorrect format",
    "Mismatch",
    "Internal error",
]
##################################

# Default absolute tolerance applied when comparing floating point results.
DEFAULT_FP_TOLERANCE = 1e-3
# Module-wide switch: when False the checker stays silent instead of printing
# coloured result lines.
result_printing = True


def set_print_result(enabled):
    """Enable or disable printing of result lines by this module.

    Args:
        enabled: truthy to print results, falsy to suppress them. The value
            is stored as-is in the module-level ``result_printing`` flag.
    """
    global result_printing
    result_printing = enabled


def _print_result(color, msg):
    """Print a coloured result line unless printing has been disabled.

    Args:
        color: colour value forwarded unchanged to ``print_color``.
        msg: text to print.
    """
    # The flag is only read here, so no ``global`` declaration is needed.
    if not result_printing:
        return
    print_color(color, msg)


def compliance_check(
    imp_result_path,
    ref_result_path,
    bnd_result_path,
    test_name,
    compliance_config,
    ofm_name,
    verify_lib_path,
):
    """Run the verifier library over one output tensor and report the outcome.

    Args:
        imp_result_path: implementation result, forwarded to
            ``VerifierLibrary.verify_data``.
        ref_result_path: reference result, forwarded likewise.
        bnd_result_path: bounds result, forwarded likewise (may be None).
        test_name: name used in the printed result line.
        compliance_config: compliance configuration handed to the verifier.
        ofm_name: name of the output tensor being verified.
        verify_lib_path: location used to construct the ``VerifierLibrary``.

    Returns:
        A ``(TestResult, tolerance, message)`` tuple. The tolerance is always
        0.0 here; the message is empty on PASS.

    NOTE(review): despite the ``*_path`` names, ``test_check`` in this file
    calls this function with the already-loaded numpy arrays - confirm which
    form ``VerifierLibrary.verify_data`` expects.
    """
    try:
        vlib = VerifierLibrary(verify_lib_path)
    except VerifierError as e:
        _print_result(LogColors.RED, f"INTERNAL ERROR {test_name}")
        msg = f"Could not load verifier library: {e}"
        return (TestResult.INTERNAL_ERROR, 0.0, msg)

    success = vlib.verify_data(
        ofm_name, compliance_config, imp_result_path, ref_result_path, bnd_result_path
    )
    if not success:
        _print_result(LogColors.RED, f"Results NON-COMPLIANT {test_name}")
        return (TestResult.MISMATCH, 0.0, "Non-compliance implementation results found")

    _print_result(LogColors.GREEN, f"Results PASS {test_name}")
    return (TestResult.PASS, 0.0, "")


def test_check(
    ref_result_path,
    imp_result_path,
    test_name=None,
    quantize_tolerance=0,
    float_tolerance=DEFAULT_FP_TOLERANCE,
    misc_checks=(),
    test_desc=None,
    bnd_result_path=None,
    ofm_name=None,
    verify_lib_path=None,
):
    """Check if the result is the same as the expected reference.

    Args:
        ref_result_path: path to the reference result (numpy file).
        imp_result_path: path to the implementation result (numpy file).
        test_name: name used in printed result lines; defaults to "test".
        quantize_tolerance: maximum absolute difference allowed for int32 and
            int64 results.
        float_tolerance: absolute tolerance (``atol``) for float32 and float16
            results.
        misc_checks: collection of extra check names; only "bf16" is
            recognised here.
        test_desc: optional test description dictionary. When it contains
            ``meta.compliance`` the verifier library is used instead of the
            direct comparison.
        bnd_result_path: optional path to the bounds result (numpy file).
        ofm_name: output tensor name for the compliance check; defaults to
            the only entry in ``test_desc["ofm_name"]``.
        verify_lib_path: path handed on to the compliance check.

    Returns:
        A ``(TestResult, tolerance, message)`` tuple; the message is empty on
        PASS and otherwise describes the failure.
    """
    # Settle the display name first so every message below can use it.
    if test_name is None:
        test_name = "test"

    if test_desc:
        # New compliance method - first get test details
        try:
            TestDescSchemaValidator().validate_config(test_desc)
        except Exception as e:
            _print_result(LogColors.RED, f"Test INCORRECT FORMAT {test_name}")
            msg = f"Incorrect test format: {e}"
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    names = ["Implementation", "Reference", "Bounds"]
    paths = [imp_result_path, ref_result_path, bnd_result_path]
    arrays = [None, None, None]

    # Check the files exist and are in the right format
    for idx, (name, path) in enumerate(zip(names, paths)):
        if path is None and name == "Bounds":
            # Bounds can be None - skip it
            continue
        # A missing (None) mandatory path is reported like a missing file;
        # Path() lets callers pass plain strings as well as Path objects.
        if path is None or not Path(path).is_file():
            _print_result(LogColors.RED, f"{name} MISSING FILE {test_name}")
            msg = f"Missing {name} file: {path}"
            return (TestResult.MISSING_FILE, 0.0, msg)
        try:
            arrays[idx] = np.load(path)
        except Exception as e:
            _print_result(LogColors.RED, f"{name} INCORRECT FORMAT {test_name}")
            msg = f"Incorrect numpy format of {path}\nnumpy.load exception: {e}"
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    if test_desc and "meta" in test_desc and "compliance" in test_desc["meta"]:
        # Switch to using the verifier library for full compliance
        if ofm_name is None:
            if len(test_desc["ofm_name"]) > 1:
                _print_result(LogColors.RED, f"Output Name MISSING FILE {test_name}")
                msg = "Must specify output name (ofm_name) to check as multiple found in desc.json"
                return (TestResult.MISSING_FILE, 0.0, msg)
            ofm_name = test_desc["ofm_name"][0]

        compliance_json = test_desc["meta"]["compliance"]

        return compliance_check(
            *arrays,
            test_name,
            compliance_json,
            ofm_name,
            verify_lib_path,
        )

    # Else continue with original checking method
    test_result, reference_result, _ = arrays

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        _print_result(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # Size = 1 tensors can be equivalently represented as having rank 0 or rank
    # >= 0, allow that special case
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)
    difference = None

    if np.shape(test_result) != np.shape(reference_result):
        _print_result(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        # Reference shape first so that it matches the label in the message.
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(reference_result), np.shape(test_result)
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Perform miscellaneous checks
    if "bf16" in misc_checks:
        # Ensure floats are valid bfloat16 values
        test_res_is_bf16 = all(float32_is_valid_bfloat16(f) for f in test_result.flat)
        ref_res_is_bf16 = all(
            float32_is_valid_bfloat16(f) for f in reference_result.flat
        )
        if not (test_res_is_bf16 and ref_res_is_bf16):
            msg = (
                "All output values must be valid bfloat16. "
                f"reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}"
            )
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # for quantized test, allow +-(quantize_tolerance) error
    if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
        difference = reference_result - test_result
        abs_difference = np.absolute(difference)

        if np.all(abs_difference <= quantize_tolerance):
            _print_result(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")

        # Search for the smallest tolerance (capped at 10) that would have
        # passed; the comparison must use the growing candidate value.
        tolerance = quantize_tolerance + 1
        while not np.all(abs_difference <= tolerance):
            tolerance = tolerance + 1
            if tolerance > 10:
                break

        if tolerance > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                tolerance
            )
        # Fall-through to below to add failure values

    elif reference_result.dtype == bool:
        # The dtypes were already checked to be equal above, so the test
        # result is known to be boolean too. All boolean values must match.
        if np.array_equal(reference_result, test_result):
            _print_result(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        difference = None
        # Fall-through to below to add failure values

    # TODO: update for fp16 tolerance
    elif reference_result.dtype == np.float32 or reference_result.dtype == np.float16:
        tolerance = float_tolerance
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            _print_result(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        difference = reference_result - test_result
        # Fall-through to below to add failure values
    else:
        _print_result(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    _print_result(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
    # Keep large tensors readable in the message by eliding their middle.
    np.set_printoptions(threshold=128, edgeitems=2)

    if difference is not None:
        tolerance_needed = np.amax(np.absolute(difference))
        msg = "{}\n-- tolerance_needed: {}".format(msg, tolerance_needed)

    msg = "{}\n>> reference_result: {}\n{}".format(
        msg, reference_result.shape, reference_result
    )
    msg = "{}\n<< test_result: {}\n{}".format(msg, test_result.shape, test_result)

    if difference is not None:
        msg = "{}\n!! difference_result: \n{}".format(msg, difference)
    return (TestResult.MISMATCH, tolerance, msg)


def main(argv=None):
    """Check that the supplied reference and result files have the same contents.

    Args:
        argv: list of command line arguments to parse; None makes argparse
            read them from the process command line.

    Returns:
        The TestResult of the check, or MISSING_FILE / INCORRECT_FORMAT when
        the optional test directory or its desc.json cannot be used.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "ref_result_path",
        type=Path,
        help="path to the reference model result file to check",
    )
    parser.add_argument(
        "imp_result_path",
        type=Path,
        help="path to the implementation result file to check",
    )
    parser.add_argument(
        "--fp-tolerance", type=float, default=DEFAULT_FP_TOLERANCE, help="FP tolerance"
    )
    parser.add_argument(
        "--test_path", type=Path, help="path to the test that produced the results"
    )
    parser.add_argument(
        "--bnd-result-path",
        type=Path,
        help="path to the reference model bounds result file for the dot product compliance check",
    )
    parser.add_argument(
        "--ofm-name",
        type=str,
        help="name of the output tensor to check, defaults to the first ofm_name listed in the test",
    )
    parser.add_argument(
        "--verify-lib-path",
        type=Path,
        help="path to TOSA verify library",
    )
    args = parser.parse_args(argv)

    # Without a test directory there is no description and no name to report.
    test_desc = None
    test_name = None

    if args.test_path:
        # Get details from the test path
        test_desc_path = args.test_path / "desc.json"
        if not (args.test_path.is_dir() and test_desc_path.is_file()):
            print(f"Invalid test directory {args.test_path!s}")
            return TestResult.MISSING_FILE

        try:
            with test_desc_path.open("r") as fd:
                test_desc = json.load(fd)
        except Exception as e:
            print(f"Invalid test description file {test_desc_path!s}: {e}")
            return TestResult.INCORRECT_FORMAT
        test_name = args.test_path.name

    result, tolerance, msg = test_check(
        args.ref_result_path,
        args.imp_result_path,
        float_tolerance=args.fp_tolerance,
        test_name=test_name,
        test_desc=test_desc,
        bnd_result_path=args.bnd_result_path,
        ofm_name=args.ofm_name,
        verify_lib_path=args.verify_lib_path,
    )
    if result != TestResult.PASS:
        print(msg)

    return result


if __name__ == "__main__":
    # Raise SystemExit directly: the exit() helper is installed by the site
    # module and is not guaranteed to exist in every interpreter setup.
    raise SystemExit(main())