Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 1 | """TOSA result checker script.""" |
| 2 | # Copyright (c) 2020-2022, ARM Limited. |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | import argparse |
| 5 | import os |
| 6 | from enum import Enum |
| 7 | from enum import IntEnum |
| 8 | from enum import unique |
| 9 | from pathlib import Path |
| 10 | |
| 11 | import numpy as np |
James Ward | 24dbc42 | 2022-10-19 12:20:31 +0100 | [diff] [blame^] | 12 | from generator.tosa_utils import float32_is_valid_bfloat16 |
Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 13 | |
| 14 | ################################## |
# Module-wide switch: print_color() emits escape sequences only while this is
# truthy; set_print_in_color() is the function that changes it.
color_printing = True
Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 16 | |
| 17 | |
@unique
class LogColors(Enum):
    """Terminal escape sequences used to colorize status messages.

    Each member's value is the raw escape string that gets written in front
    of a message; ``NONE`` is the sequence written after the message.
    """

    # Written after a colored message to end the styling.
    NONE = "\u001b[0m"
    GREEN = "\u001b[32;1m"
    RED = "\u001b[31;1m"
    YELLOW = "\u001b[33;1m"
    BOLD_WHITE = "\u001b[1m"
| 27 | |
| 28 | |
def set_print_in_color(enabled):
    """Turn colored status output on or off for this module.

    The value of ``enabled`` is stored unchanged in the module-level
    ``color_printing`` flag.
    """
    # Rebind the module-level flag rather than creating a local.
    global color_printing
    color_printing = enabled
| 33 | |
| 34 | |
def print_color(color, msg):
    """Print ``msg``, wrapped in the escape codes of ``color`` when enabled.

    When the module-level ``color_printing`` flag is falsy the message is
    printed unchanged. Otherwise it is prefixed with ``color.value`` and
    followed by ``LogColors.NONE.value``.
    """
    if color_printing:
        text = "{}{}{}".format(color.value, msg, LogColors.NONE.value)
    else:
        text = msg
    print(text)
| 42 | |
| 43 | |
@unique
class TestResult(IntEnum):
    """Outcome codes for comparing a result file against its reference.

    The members are integers so that a value can be handed straight back as
    the command line return code.
    """

    # PASS has to stay 0: a zero return code signals success on the
    # command line.
    PASS = 0
    MISSING_FILE = 1
    INCORRECT_FORMAT = 2
    MISMATCH = 3
    INTERNAL_ERROR = 4
| 54 | |
| 55 | |
# Human-readable text for each TestResult code, listed in the same order as
# the enum values (position 0, matching PASS, is the empty string).
# NOTE(review): presumably indexed by the TestResult value by callers outside
# this file - confirm before reordering either definition.
TestResultErrorStr = [
    "",
    "Missing file",
    "Incorrect format",
    "Mismatch",
    "Internal error",
]
| 63 | ################################## |
| 64 | |
| 65 | |
def test_check(
    reference,
    result,
    test_name="test",
    quantize_tolerance=0,
    float_tolerance=1e-3,
    misc_checks=None,
):
    """Check if the result is the same as the expected reference.

    Both files are loaded with ``numpy.load`` and compared by dtype, then by
    shape (after squeezing out size-1 dimensions) and finally by value, using
    a rule chosen from the reference dtype. A colored status line containing
    ``test_name`` is printed for most outcomes.

    Args:
        reference: path of the numpy file holding the expected values.
        result: path of the numpy file holding the values under test.
        test_name: label used in the printed status lines.
        quantize_tolerance: largest absolute per-element difference accepted
            for int32/int64 data.
        float_tolerance: absolute tolerance (``atol``) given to
            ``numpy.allclose`` for float32/float16 data.
        misc_checks: optional collection of extra check names. Only
            ``"bf16"`` is acted on here; ``None`` means no extra checks.

    Returns:
        A ``(TestResult, tolerance, message)`` tuple. ``message`` is empty
        on PASS. ``tolerance`` is ``float_tolerance`` for float comparisons,
        the smallest matching integer tolerance (or 11 when none up to 10
        matches) for integer mismatches, and 0.0 in every other case.
    """
    # Avoid a shared mutable default; an empty tuple behaves the same for
    # the membership test below.
    if misc_checks is None:
        misc_checks = ()

    if not os.path.isfile(reference):
        print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
        msg = "Missing reference file: {}".format(reference)
        return (TestResult.MISSING_FILE, 0.0, msg)
    if not os.path.isfile(result):
        print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
        msg = "Missing result file: {}".format(result)
        return (TestResult.MISSING_FILE, 0.0, msg)

    try:
        test_result = np.load(result)
    except Exception as e:
        print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(result, e)
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)
    try:
        reference_result = np.load(reference)
    except Exception as e:
        print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
            reference, e
        )
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # Size = 1 tensors can be equivalently represented as rank 0 or as a
    # higher rank with all dimensions of size 1; squeeze both so that special
    # case compares equal.
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)

    if np.shape(test_result) != np.shape(reference_result):
        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        # The message names the reference first, so format it first.
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(reference_result), np.shape(test_result)
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Perform miscellaneous checks
    if "bf16" in misc_checks:
        # Ensure floats are valid bfloat16 values; generators let all() stop
        # at the first invalid element.
        test_res_is_bf16 = all(float32_is_valid_bfloat16(f) for f in test_result.flat)
        ref_res_is_bf16 = all(
            float32_is_valid_bfloat16(f) for f in reference_result.flat
        )
        if not (test_res_is_bf16 and ref_res_is_bf16):
            msg = (
                "All output values must be valid bfloat16. "
                "reference_result: {}; test_result: {}".format(
                    ref_res_is_bf16, test_res_is_bf16
                )
            )
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # for quantized test, allow +-(quantize_tolerance) error
    if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
        abs_diff = np.absolute(reference_result - test_result)

        if np.all(abs_diff <= quantize_tolerance):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")

        # Find the smallest tolerance (searching up to 10) that would have
        # made the comparison pass, to report how far off the result is.
        tolerance = quantize_tolerance + 1
        while tolerance <= 10 and not np.all(abs_diff <= tolerance):
            tolerance = tolerance + 1

        if tolerance > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                tolerance
            )
        # Fall-through to below to add failure values

    elif reference_result.dtype == bool:
        # The dtypes were already checked to be equal above, so test_result
        # is boolean as well. All boolean values must match.
        if np.array_equal(reference_result, test_result):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        # Fall-through to below to add failure values

    # TODO: update for fp16 tolerance
    elif reference_result.dtype == np.float32 or reference_result.dtype == np.float16:
        tolerance = float_tolerance
        # Only atol is overridden; numpy's default rtol still applies.
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        # Fall-through to below to add failure values
    else:
        print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
    # Summarize arrays with more than 128 elements in the message.
    # Note: this changes numpy's print options for the rest of the process.
    np.set_printoptions(threshold=128)
    msg = "{}\ntest_result: {}\n{}".format(msg, test_result.shape, test_result)
    msg = "{}\nreference_result: {}\n{}".format(
        msg, reference_result.shape, reference_result
    )
    return (TestResult.MISMATCH, tolerance, msg)
| 191 | |
| 192 | |
def main(argv=None):
    """Check that the supplied reference and result files are the same.

    Args:
        argv: argument list handed to ``ArgumentParser.parse_args``.

    Returns:
        The ``TestResult`` reported by ``test_check``; the failure message
        is printed first when the result is not PASS.
    """
    parser = argparse.ArgumentParser()
    positional_args = (
        ("reference_path", "the path to the reference file to test"),
        ("result_path", "the path to the result file to test"),
    )
    for arg_name, help_text in positional_args:
        parser.add_argument(arg_name, type=Path, help=help_text)
    options = parser.parse_args(argv)

    # The tolerance element of the returned tuple is not needed here.
    outcome, _, message = test_check(options.reference_path, options.result_path)
    if outcome != TestResult.PASS:
        print(message)

    return outcome
| 211 | |
| 212 | |
if __name__ == "__main__":
    # Raise SystemExit directly instead of calling the bare exit() helper:
    # exit() is added by the site module for interactive use and is not
    # guaranteed to be available. The value returned by main() becomes the
    # process exit status.
    raise SystemExit(main())