Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 1 | """TOSA result checker script.""" |
| 2 | # Copyright (c) 2020-2022, ARM Limited. |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | import argparse |
| 5 | import os |
| 6 | from enum import Enum |
| 7 | from enum import IntEnum |
| 8 | from enum import unique |
| 9 | from pathlib import Path |
| 10 | |
| 11 | import numpy as np |
| 12 | |
##################################
# Module-wide switch consulted by print_color(); change it through
# set_print_in_color() rather than assigning it directly.
color_printing = True
| 16 | |
@unique
class LogColors(Enum):
    """ANSI terminal escape sequences used to color status messages."""

    # Reset all attributes back to the terminal default (SGR 0).
    NONE = "\u001b[0m"
    # Green foreground, bold (SGR 32;1).
    GREEN = "\u001b[32;1m"
    # Red foreground, bold (SGR 31;1).
    RED = "\u001b[31;1m"
    # Yellow foreground, bold (SGR 33;1).
    YELLOW = "\u001b[33;1m"
    # Bold only (SGR 1); no explicit foreground color is set.
    BOLD_WHITE = "\u001b[1m"
| 26 | |
| 27 | |
def set_print_in_color(enabled):
    """Turn colored status output on or off for the whole module.

    Args:
        enabled: truthy to have print_color wrap messages in escape
            sequences, falsy to print them unchanged.
    """
    global color_printing  # rebinding the module-level flag, not a local
    color_printing = enabled
| 32 | |
| 33 | |
def print_color(color, msg):
    """Print *msg*, wrapped in *color*'s escape sequence when coloring is on.

    Args:
        color: a LogColors member whose value prefixes the message.
        msg: the object to print (formatted with str()).
    """
    # Only reading the module flag here, so no `global` declaration needed.
    if color_printing:
        msg = "{}{}{}".format(color.value, msg, LogColors.NONE.value)
    print(msg)
| 41 | |
| 42 | |
@unique
class TestResult(IntEnum):
    """Outcome codes returned by the checker (also used as the exit status)."""

    # Must stay 0 so a passing check maps to a successful process exit code.
    PASS = 0
    # Reference or result file does not exist.
    MISSING_FILE = 1
    # A file could not be loaded as a numpy array.
    INCORRECT_FORMAT = 2
    # dtype, shape, or values differ (or the dtype is unsupported).
    MISMATCH = 3
    # Not produced by test_check in this file; reserved for callers.
    INTERNAL_ERROR = 4
| 53 | |
| 54 | |
# Human-readable description of each TestResult, indexed by its integer
# value; PASS carries no error text.
TestResultErrorStr = [
    "",  # TestResult.PASS
    "Missing file",  # TestResult.MISSING_FILE
    "Incorrect format",  # TestResult.INCORRECT_FORMAT
    "Mismatch",  # TestResult.MISMATCH
    "Internal error",  # TestResult.INTERNAL_ERROR
]
##################################
| 63 | |
| 64 | |
def _load_numpy_array(path):
    """Load and return the single numpy array stored at *path*.

    Raises:
        ValueError: if the file does not hold a single array (np.load returns
            an NpzFile for .npz archives, which has no dtype/shape).
        Exception: anything np.load itself raises for unreadable data.
    """
    data = np.load(path)
    if not isinstance(data, np.ndarray):
        # Release the archive's file handle before reporting the problem.
        close = getattr(data, "close", None)
        if close is not None:
            close()
        raise ValueError(
            "expected a single numpy array, got {}".format(type(data).__name__)
        )
    return data


def _abs_int_difference(a, b):
    """Return the element-wise |a - b| of two integer arrays as uint64.

    Subtracting int32/int64 arrays directly can wrap around (for example
    INT32_MAX - INT32_MIN), which would make a huge mismatch look tiny.
    Both operands are reinterpreted as uint64 and the smaller value is always
    subtracted from the larger, so modular uint64 arithmetic yields the exact
    distance for any pair of signed integers up to 64 bits wide.
    """
    a_u = a.astype(np.uint64)
    b_u = b.astype(np.uint64)
    return np.where(a >= b, a_u - b_u, b_u - a_u)


def test_check(
    reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3
):
    """Check if the result is the same as the expected reference.

    Both files are loaded with numpy and compared in this order: dtype, shape
    (after dropping every size-1 dimension), then values:

    * int32/int64: every absolute difference must be <= quantize_tolerance.
    * bool: all values must be identical.
    * float32: np.allclose with atol=float_tolerance (numpy's default rtol)
      and NaNs in matching positions treated as equal.

    Any other dtype is reported as an unsupported MISMATCH.

    Args:
        reference: path to the expected-output numpy file.
        result: path to the produced-output numpy file.
        test_name: name shown in the colored status line.
        quantize_tolerance: allowed absolute error for integer results.
        float_tolerance: absolute tolerance for float32 results.

    Returns:
        A (TestResult, tolerance, message) tuple. The message is "" on PASS.
        For an integer value mismatch, tolerance is
        max(quantize_tolerance + 1, largest absolute difference).
    """
    if not os.path.isfile(reference):
        print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
        msg = "Missing reference file: {}".format(reference)
        return (TestResult.MISSING_FILE, 0.0, msg)
    if not os.path.isfile(result):
        print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
        msg = "Missing result file: {}".format(result)
        return (TestResult.MISSING_FILE, 0.0, msg)

    try:
        test_result = _load_numpy_array(result)
    except Exception as e:
        print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(result, e)
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)
    try:
        reference_result = _load_numpy_array(reference)
    except Exception as e:
        print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
            reference, e
        )
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Shape comparison. Size-1 dimensions are dropped from both sides first so
    # equivalent layouts compare equal, e.g. a rank-0 scalar vs shape (1,), or
    # shape (1, 3) vs shape (3,).
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)

    if test_result.shape != reference_result.shape:
        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        msg = "Shapes mismatch: Reference {} vs {}".format(
            reference_result.shape, test_result.shape
        )
        return (TestResult.MISMATCH, 0.0, msg)

    if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
        # Quantized results may differ by up to +-quantize_tolerance.
        abs_diff = _abs_int_difference(reference_result, test_result)
        if np.all(abs_diff <= quantize_tolerance):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        # Report the tolerance that would have been needed for a pass.
        tolerance = max(quantize_tolerance + 1, int(abs_diff.max()))
        if tolerance > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                tolerance
            )
        # Fall-through to below to add failure values

    elif reference_result.dtype == bool:
        # Boolean values must match exactly.
        if np.array_equal(reference_result, test_result):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        # Fall-through to below to add failure values

    elif reference_result.dtype == np.float32:
        tolerance = float_tolerance
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        # Fall-through to below to add failure values

    else:
        print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
    # Summarise large arrays; scoped so numpy's global print options are left
    # untouched for the caller.
    with np.printoptions(threshold=128):
        msg = "{}\ntest_result: {}\n{}".format(msg, test_result.shape, test_result)
        msg = "{}\nreference_result: {}\n{}".format(
            msg, reference_result.shape, reference_result
        )
    return (TestResult.MISMATCH, tolerance, msg)
| 171 | |
| 172 | |
def main(argv=None):
    """Compare a reference numpy file against a result numpy file.

    Args:
        argv: argument list to parse; None lets argparse read the real
            command line.

    Returns:
        The TestResult of the comparison (TestResult.PASS, i.e. 0, on success).
        The mismatch/error message is printed when the check does not pass.
    """
    parser = argparse.ArgumentParser()
    for arg_name, help_text in (
        ("reference_path", "the path to the reference file to test"),
        ("result_path", "the path to the result file to test"),
    ):
        parser.add_argument(arg_name, type=Path, help=help_text)
    args = parser.parse_args(argv)

    outcome, _, msg = test_check(args.reference_path, args.result_path)
    if outcome != TestResult.PASS:
        print(msg)
    return outcome
| 191 | |
| 192 | |
if __name__ == "__main__":
    import sys

    # Use sys.exit rather than the exit() helper, which the site module adds
    # for interactive use and which is absent under `python -S`.
    sys.exit(main())