blob: 66864c21daf3000b67ad5a2b4d1b4733e103c790 [file] [log] [blame]
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00001"""TOSA result checker script."""
2# Copyright (c) 2020-2022, ARM Limited.
3# SPDX-License-Identifier: Apache-2.0
4import argparse
5import os
6from enum import Enum
7from enum import IntEnum
8from enum import unique
9from pathlib import Path
10
11import numpy as np
12
13##################################
# Module-wide switch consulted by print_color(); change it through
# set_print_in_color() rather than assigning it directly.
color_printing = True
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +000015
16
@unique
class LogColors(Enum):
    """ANSI terminal escape sequences used to color status messages."""

    # Reset sequence; print_color() appends it after every colored message.
    NONE = "\u001b[0m"
    # Foreground color combined with the bold attribute (";1").
    GREEN = "\u001b[32;1m"
    RED = "\u001b[31;1m"
    YELLOW = "\u001b[33;1m"
    # Bold attribute only; the terminal's default color is kept.
    BOLD_WHITE = "\u001b[1m"
26
27
def set_print_in_color(enabled):
    """Switch ANSI-colored output of print_color() on or off.

    Args:
        enabled: stored as-is in the module-level ``color_printing`` flag,
            whose truthiness print_color() checks.
    """
    global color_printing
    color_printing = enabled
32
33
def print_color(color, msg):
    """Print ``msg``, wrapped in ``color``'s escape sequence when enabled.

    When the module flag ``color_printing`` is falsy the message is printed
    unchanged; otherwise it is prefixed with ``color.value`` and followed by
    the reset sequence ``LogColors.NONE.value``.
    """
    if color_printing:
        msg = "{}{}{}".format(color.value, msg, LogColors.NONE.value)
    print(msg)
41
42
@unique
class TestResult(IntEnum):
    """Outcome codes produced by test_check."""

    # PASS must remain 0: main() returns this value and the script uses it
    # as the process exit status, where 0 means success.
    PASS = 0
    MISSING_FILE = 1
    INCORRECT_FORMAT = 2
    MISMATCH = 3
    INTERNAL_ERROR = 4
53
54
# Human-readable description of each TestResult, indexed by the result's
# integer value (e.g. TestResultErrorStr[3] is the text for MISMATCH = 3).
# PASS has no error text.
TestResultErrorStr = [
    "",  # PASS
    "Missing file",  # MISSING_FILE
    "Incorrect format",  # INCORRECT_FORMAT
    "Mismatch",  # MISMATCH
    "Internal error",  # INTERNAL_ERROR
]
62##################################
63
64
def _load_numpy_file(path):
    """Load a single array from ``path`` with ``numpy.load``.

    Returns:
        ``(array, None)`` on success, or ``(None, message)`` explaining why the
        file could not be used as a single numpy array.
    """
    try:
        data = np.load(path)
    except Exception as e:  # numpy.load raises several unrelated exception types
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(path, e)
        return None, msg
    if not isinstance(data, np.ndarray):
        # A .npz archive loads as an open NpzFile rather than an array; close
        # it so the file handle is not leaked, and report the wrong format.
        if hasattr(data, "close"):
            data.close()
        msg = "Incorrect numpy format of {}\nexpected a single array, got {}".format(
            path, type(data).__name__
        )
        return None, msg
    return data, None


def test_check(
    reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3
):
    """Check if the result is the same as the expected reference.

    Both files are loaded with ``numpy.load``. The dtypes must match exactly.
    The shapes are compared after removing every size-1 dimension
    (``np.squeeze``). A status line is printed via print_color.

    Args:
        reference: path to the expected-output numpy file.
        result: path to the produced-output numpy file.
        test_name: name shown in the printed status line.
        quantize_tolerance: largest absolute difference accepted for integer
            results.
        float_tolerance: absolute tolerance (``atol``) given to ``np.allclose``
            for floating-point results.

    Returns:
        A ``(TestResult, tolerance, message)`` tuple; ``message`` is empty on
        success. ``tolerance`` is the largest absolute difference found for
        an integer mismatch, ``float_tolerance`` for float comparisons, and
        0.0 otherwise.
    """
    if not os.path.isfile(reference):
        print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
        msg = "Missing reference file: {}".format(reference)
        return (TestResult.MISSING_FILE, 0.0, msg)
    if not os.path.isfile(result):
        print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
        msg = "Missing result file: {}".format(result)
        return (TestResult.MISSING_FILE, 0.0, msg)

    test_result, msg = _load_numpy_file(result)
    if test_result is None:
        print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)
    reference_result, msg = _load_numpy_file(reference)
    if reference_result is None:
        print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # A size-1 tensor can be stored as rank 0 or as rank >= 1, so every size-1
    # dimension is dropped before comparing. Note that this also makes shapes
    # such as (1, 3) and (3,) compare equal.
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)

    if test_result.shape != reference_result.shape:
        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        msg = "Shapes mismatch: Reference {} vs {}".format(
            reference_result.shape, test_result.shape
        )
        return (TestResult.MISMATCH, 0.0, msg)

    integer_types = (
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.uint8,
        np.uint16,
        np.uint32,
    )
    float_types = (np.float16, np.float32, np.float64)

    if reference_result.dtype in integer_types:
        # For quantized tests allow +-(quantize_tolerance) error. Widen to
        # int64 before subtracting. In the native type the difference can wrap
        # around (e.g. int32 max - (-1)) and make a large error look small,
        # and unsigned types cannot hold negative differences. Differences
        # between extreme int64 values can still overflow.
        abs_diff = np.absolute(
            reference_result.astype(np.int64) - test_result.astype(np.int64)
        )
        max_diff = int(abs_diff.max()) if abs_diff.size else 0
        if max_diff <= quantize_tolerance:
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        tolerance = max_diff
        if max_diff > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                max_diff
            )
        # Fall-through to below to add failure values

    elif reference_result.dtype == bool:
        # Boolean values must match exactly.
        if np.array_equal(reference_result, test_result):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        # Fall-through to below to add failure values

    elif reference_result.dtype in float_types:
        tolerance = float_tolerance
        # NaNs in the same positions compare equal. np.allclose also applies
        # its default relative tolerance on top of atol.
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        # Fall-through to below to add failure values

    else:
        print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
    # Scoped print options: summarise large arrays in the message without
    # permanently changing the caller's global numpy print settings.
    with np.printoptions(threshold=128):
        msg = "{}\ntest_result: {}\n{}".format(msg, test_result.shape, test_result)
        msg = "{}\nreference_result: {}\n{}".format(
            msg, reference_result.shape, reference_result
        )
    return (TestResult.MISMATCH, tolerance, msg)
171
172
def main(argv=None):
    """Check that the supplied reference and result files are the same.

    Args:
        argv: argument list to parse instead of the process command line
            (passed straight to ``ArgumentParser.parse_args``).

    Returns:
        The TestResult from test_check. TestResult.PASS is 0, so the value can
        be used directly as the process exit status. The failure message is
        printed when the result is not PASS.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "reference_path", type=Path, help="the path to the reference file to test"
    )
    parser.add_argument(
        "result_path", type=Path, help="the path to the result file to test"
    )
    # Expose test_check's tolerances. The defaults mirror test_check's own
    # defaults, so existing invocations behave exactly as before.
    parser.add_argument(
        "--quantize-tolerance",
        type=int,
        default=0,
        help="maximum absolute difference allowed for integer results "
        "(default: %(default)s)",
    )
    parser.add_argument(
        "--float-tolerance",
        type=float,
        default=1e-3,
        help="absolute tolerance for floating-point results (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    result, _tolerance, msg = test_check(
        args.reference_path,
        args.result_path,
        quantize_tolerance=args.quantize_tolerance,
        float_tolerance=args.float_tolerance,
    )
    if result != TestResult.PASS:
        print(msg)

    return result
191
192
if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(): the latter is meant for
    # the interactive interpreter and is missing when Python runs with -S.
    import sys

    sys.exit(main())