blob: 3a15de951887c03926f9df2a430e19754f115a3e [file] [log] [blame]
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00001"""TOSA result checker script."""
2# Copyright (c) 2020-2022, ARM Limited.
3# SPDX-License-Identifier: Apache-2.0
4import argparse
5import os
6from enum import Enum
7from enum import IntEnum
8from enum import unique
9from pathlib import Path
10
11import numpy as np
12
13##################################
# When True, print_color() prints messages as-is, without wrapping them in
# shell color escape sequences.
no_color_printing = False
15
16
@unique
class LogColors(Enum):
    """Shell escape sequence colors for logging.

    Each member's value is an ANSI SGR escape sequence. ``print_color`` wraps
    a message in the chosen member's value followed by ``NONE.value``, so the
    terminal attributes are reset after every colored message.
    """

    NONE = "\u001b[0m"  # reset all attributes
    GREEN = "\u001b[32;1m"  # green foreground, bold
    RED = "\u001b[31;1m"  # red foreground, bold
    YELLOW = "\u001b[33;1m"  # yellow foreground, bold
    BOLD_WHITE = "\u001b[1m"  # bold only; keeps the terminal's current foreground
26
27
def print_color(color, msg):
    """Print *msg*, wrapped in the escape sequence of *color* unless disabled.

    If the module-level ``no_color_printing`` flag is set, *msg* is printed
    unchanged. Otherwise it is prefixed with ``color.value`` and followed by
    the ``LogColors.NONE`` reset sequence.
    """
    text = (
        msg
        if no_color_printing
        else "{}{}{}".format(color.value, msg, LogColors.NONE.value)
    )
    print(text)
34
35
@unique
class TestResult(IntEnum):
    """Test result values.

    ``test_check`` returns one of these as the first element of its result
    tuple. The members are integers, so a value can be used directly as a
    process exit status. ``TestResultErrorStr`` lists a description for each
    value in the same order; keep the two in sync.
    """

    # Note: PASS must be 0 for command line return success
    PASS = 0
    # The reference or the result file does not exist.
    MISSING_FILE = 1
    # The reference or the result file could not be loaded with numpy.load.
    INCORRECT_FORMAT = 2
    # Type, shape or value mismatch, or an unsupported result dtype.
    MISMATCH = 3
    # Not produced by test_check; presumably reported by other tooling — verify.
    INTERNAL_ERROR = 4
46
47
# Human-readable description for each TestResult value, listed in the same
# order as the enum values (PASS has no error text). Presumably indexed by the
# integer value of a TestResult, so keep this list in sync with the enum.
TestResultErrorStr = [
    "",  # PASS
    "Missing file",  # MISSING_FILE
    "Incorrect format",  # INCORRECT_FORMAT
    "Mismatch",  # MISMATCH
    "Internal error",  # INTERNAL_ERROR
]
55##################################
56
57
def test_check(
    reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3
):
    """Check if the result is the same as the expected reference.

    Both files are loaded with ``numpy.load``. They are then compared by
    dtype, by shape (after squeezing out size-1 dimensions) and by value:

    * int32/int64: each element may differ from the reference by at most
      ``quantize_tolerance``.
    * bool: every element must match exactly.
    * float32: compared with ``numpy.allclose`` using ``atol=float_tolerance``
      (plus numpy's default ``rtol``); NaNs in the same positions are equal.

    Any other dtype is reported as an unsupported-type mismatch.

    Args:
        reference: path to the reference file (anything accepted by
            ``os.path.isfile`` and ``numpy.load``).
        result: path to the result file to check.
        test_name: name shown in the colored status messages.
        quantize_tolerance: maximum allowed absolute per-element difference
            for integer results.
        float_tolerance: absolute tolerance for float32 results.

    Returns:
        A tuple ``(TestResult, tolerance, msg)``:

        * ``msg`` is empty on PASS; otherwise it describes the failure.
        * For an integer mismatch, ``tolerance`` is the largest absolute
          element difference found.
        * For float32 it is ``float_tolerance``.
        * Otherwise it is 0.0.
    """
    if not os.path.isfile(reference):
        print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
        msg = "Missing reference file: {}".format(reference)
        return (TestResult.MISSING_FILE, 0.0, msg)
    if not os.path.isfile(result):
        print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
        msg = "Missing result file: {}".format(result)
        return (TestResult.MISSING_FILE, 0.0, msg)

    try:
        test_result = np.load(result)
    except Exception as e:
        print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(result, e)
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)
    try:
        reference_result = np.load(reference)
    except Exception as e:
        print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
            reference, e
        )
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # A size-1 tensor may be stored with rank 0 or with rank >= 1 (e.g. shape
    # () vs (1, 1)). To allow that, squeeze out every size-1 dimension before
    # comparing shapes. Note this also ignores size-1 dimensions in larger
    # tensors, e.g. (2, 1, 3) and (2, 3) compare as equal shapes.
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)

    if np.shape(test_result) != np.shape(reference_result):
        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        # The reference shape comes first, matching the message wording.
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(reference_result), np.shape(test_result)
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # for quantized test, allow +-(quantize_tolerance) error
    if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
        # Subtract in int64 so that int32 extremes cannot wrap around and
        # hide (or fake) a large difference.
        abs_diff = np.absolute(
            reference_result.astype(np.int64) - test_result.astype(np.int64)
        )
        if np.all(abs_diff <= quantize_tolerance):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")

        # The largest difference is the smallest tolerance that would pass.
        tolerance = int(np.max(abs_diff))
        if tolerance > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                tolerance
            )
        # Fall-through to below to add failure values

    elif reference_result.dtype == bool:
        # Both arrays are boolean (dtypes were checked equal above) and every
        # element must match exactly.
        if np.array_equal(reference_result, test_result):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        # Fall-through to below to add failure values

    elif reference_result.dtype == np.float32:
        tolerance = float_tolerance
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        # Fall-through to below to add failure values

    else:
        print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
    # Limit the size of the array dumps only while building this message,
    # instead of changing numpy's global print options for the whole process.
    with np.printoptions(threshold=128):
        msg = "{}\ntest_result: {}\n{}".format(msg, test_result.shape, test_result)
        msg = "{}\nreference_result: {}\n{}".format(
            msg, reference_result.shape, reference_result
        )
    return (TestResult.MISMATCH, tolerance, msg)
164
165
def main(argv=None):
    """Compare a reference numpy file against a result numpy file.

    Args:
        argv: optional list of command-line arguments. ``None`` lets argparse
            read the process arguments.

    Returns:
        The ``TestResult`` of the comparison. ``TestResult.PASS`` is 0, so the
        value can be used directly as the process exit status. On failure the
        explanatory message is printed before returning.
    """
    parser = argparse.ArgumentParser()
    for arg_name, arg_help in (
        ("reference_path", "the path to the reference file to test"),
        ("result_path", "the path to the result file to test"),
    ):
        parser.add_argument(arg_name, type=Path, help=arg_help)
    args = parser.parse_args(argv)

    outcome, _, message = test_check(args.reference_path, args.result_path)
    if outcome != TestResult.PASS:
        print(message)
    return outcome
184
185
if __name__ == "__main__":
    # Use sys.exit rather than the site-injected exit(), which is meant for
    # the interactive interpreter and is missing when Python runs with -S.
    import sys

    sys.exit(main())