Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 1 | """TOSA result checker script.""" |
| 2 | # Copyright (c) 2020-2022, ARM Limited. |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | import argparse |
Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 5 | from enum import Enum |
| 6 | from enum import IntEnum |
| 7 | from enum import unique |
| 8 | from pathlib import Path |
| 9 | |
| 10 | import numpy as np |
James Ward | 24dbc42 | 2022-10-19 12:20:31 +0100 | [diff] [blame] | 11 | from generator.tosa_utils import float32_is_valid_bfloat16 |
Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 12 | |
| 13 | ################################## |
# Module-wide switch for colored output: read by print_color and changed via
# set_print_in_color
color_printing = True
Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 15 | |
| 16 | |
@unique
class LogColors(Enum):
    """Terminal escape sequences used to colorize log messages."""

    # Sequence emitted after a message to end the coloring
    NONE = "\u001b[0m"

    # Colored variants
    GREEN = "\u001b[32;1m"
    RED = "\u001b[31;1m"
    YELLOW = "\u001b[33;1m"

    # Emphasis without a color code
    BOLD_WHITE = "\u001b[1m"
| 26 | |
| 27 | |
def set_print_in_color(enabled):
    """Switch colored status output on or off.

    The value of `enabled` is stored in the module-level `color_printing`
    flag, which print_color consults before adding escape sequences.
    """
    global color_printing
    color_printing = enabled
| 32 | |
| 33 | |
def print_color(color, msg):
    """Print a status message, colorized when color printing is enabled.

    With color printing on, `msg` is wrapped between the escape sequence in
    `color.value` and the LogColors.NONE sequence; otherwise `msg` is printed
    unchanged and `color` is not used.
    """
    # Only reading the module flag here, so no `global` declaration is needed
    text = (
        "{}{}{}".format(color.value, msg, LogColors.NONE.value)
        if color_printing
        else msg
    )
    print(text)
| 41 | |
| 42 | |
@unique
class TestResult(IntEnum):
    """Outcome codes reported by a result check."""

    # PASS has to stay 0: the value is also handed back as the command line
    # return code, where 0 means success
    PASS = 0

    # Failure outcomes
    MISSING_FILE = 1
    INCORRECT_FORMAT = 2
    MISMATCH = 3
    INTERNAL_ERROR = 4
| 53 | |
| 54 | |
# Human readable descriptions of the outcomes, listed in the same order as the
# TestResult members (the first entry, matching PASS, is empty)
TestResultErrorStr = [
    "",
    "Missing file",
    "Incorrect format",
    "Mismatch",
    "Internal error",
]
| 62 | ################################## |
| 63 | |
# Default tolerance for floating point comparisons: the default value of
# test_check's float_tolerance and of the --fp-tolerance command line option
DEFAULT_FP_TOLERANCE = 1e-3
| 65 | |
Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 66 | |
def test_check(
    reference_path,
    result_path,
    test_name="test",
    quantize_tolerance=0,
    float_tolerance=DEFAULT_FP_TOLERANCE,
    misc_checks=None,
):
    """Check if the result is the same as the expected reference.

    Both files are loaded with numpy.load and compared according to the dtype
    of the reference: int32/int64 values may differ by up to
    `quantize_tolerance`, bool values must match exactly and float32/float16
    values are compared with numpy.allclose using `float_tolerance` as the
    absolute tolerance. Any other dtype is reported as a mismatch. A colored
    status line is printed for most outcomes.

    Args:
        reference_path: path object (must provide is_file()) of the numpy file
            holding the expected values.
        result_path: path object (must provide is_file()) of the numpy file
            holding the values to check.
        test_name: label used in the printed status lines.
        quantize_tolerance: largest absolute difference accepted for integer
            results.
        float_tolerance: absolute tolerance accepted for float results.
        misc_checks: optional collection of extra check names; "bf16" requires
            every value of both arrays to be a valid bfloat16 value. None
            (the default) means no extra checks.

    Returns:
        A tuple (TestResult, tolerance, message). The message is empty on
        PASS and describes the failure otherwise.
    """
    if misc_checks is None:
        # None replaces a mutable default argument; it means "no extra checks"
        misc_checks = ()

    if not reference_path.is_file():
        print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
        msg = "Missing reference file: {}".format(reference_path)
        return (TestResult.MISSING_FILE, 0.0, msg)
    if not result_path.is_file():
        print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
        msg = "Missing result file: {}".format(result_path)
        return (TestResult.MISSING_FILE, 0.0, msg)

    try:
        test_result = np.load(result_path)
    except Exception as e:
        print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
            result_path, e
        )
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)
    try:
        reference_result = np.load(reference_path)
    except Exception as e:
        print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
            reference_path, e
        )
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # Length-1 dimensions are squeezed out first so that, for example, a
    # size 1 tensor compares equal whether it is stored with rank 0 or with a
    # higher rank
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)
    difference = None

    if np.shape(test_result) != np.shape(reference_result):
        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        # The reference shape goes first to match the "Reference" label
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(reference_result), np.shape(test_result)
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Perform miscellaneous checks
    if "bf16" in misc_checks:
        # Ensure floats are valid bfloat16 values
        test_res_is_bf16 = all(float32_is_valid_bfloat16(f) for f in test_result.flat)
        ref_res_is_bf16 = all(
            float32_is_valid_bfloat16(f) for f in reference_result.flat
        )
        if not (test_res_is_bf16 and ref_res_is_bf16):
            msg = (
                "All output values must be valid bfloat16. "
                "reference_result: {}; test_result: {}"
            ).format(ref_res_is_bf16, test_res_is_bf16)
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # for quantized test, allow +-(quantize_tolerance) error
    if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
        difference = reference_result - test_result
        abs_difference = np.absolute(difference)

        if np.all(abs_difference <= quantize_tolerance):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")

        # Search for the smallest tolerance that would have passed, giving up
        # once it exceeds 10. The comparison has to use the growing
        # `tolerance`, not the fixed `quantize_tolerance`
        tolerance = quantize_tolerance + 1
        while not np.all(abs_difference <= tolerance):
            tolerance = tolerance + 1
            if tolerance > 10:
                break

        if tolerance > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                tolerance
            )
        # Fall-through to below to add failure values

    elif reference_result.dtype == bool:
        # All boolean values must match. The dtypes were already checked to be
        # equal above, so test_result is known to be bool here as well
        if np.array_equal(reference_result, test_result):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        # No numeric difference is reported for booleans
        difference = None
        # Fall-through to below to add failure values

    # TODO: update for fp16 tolerance
    elif reference_result.dtype == np.float32 or reference_result.dtype == np.float16:
        tolerance = float_tolerance
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        difference = reference_result - test_result
        # Fall-through to below to add failure values
    else:
        print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
    # Limit how much of large arrays ends up in the message
    np.set_printoptions(threshold=128, edgeitems=2)

    if difference is not None:
        tolerance_needed = np.amax(np.absolute(difference))
        msg = "{}\n-- tolerance_needed: {}".format(msg, tolerance_needed)

    msg = "{}\n>> reference_result: {}\n{}".format(
        msg, reference_result.shape, reference_result
    )
    msg = "{}\n<< test_result: {}\n{}".format(msg, test_result.shape, test_result)

    if difference is not None:
        msg = "{}\n!! difference_result: \n{}".format(msg, difference)
    return (TestResult.MISMATCH, tolerance, msg)
| 206 | |
| 207 | |
def main(argv=None):
    """Compare a reference file with a result file given on the command line.

    Args:
        argv: argument list handed to the argparse parser.

    Returns:
        The TestResult produced by test_check; the failure message is printed
        for any outcome other than PASS.
    """
    cli = argparse.ArgumentParser()
    positional_args = (
        ("reference_path", "the path to the reference file to test"),
        ("result_path", "the path to the result file to test"),
    )
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, type=Path, help=arg_help)
    cli.add_argument(
        "--fp-tolerance", type=float, default=DEFAULT_FP_TOLERANCE, help="FP tolerance"
    )
    options = cli.parse_args(argv)

    # The tolerance element of the returned tuple is not needed here
    outcome, _tolerance, message = test_check(
        options.reference_path,
        options.result_path,
        float_tolerance=options.fp_tolerance,
    )
    if outcome != TestResult.PASS:
        print(message)

    return outcome
| 229 | |
| 230 | |
if __name__ == "__main__":
    # Raise SystemExit directly instead of calling exit(): that helper is added
    # by the site module for interactive use and is not guaranteed to exist.
    # The value returned by main() becomes the process exit status.
    raise SystemExit(main())