| """TOSA result checker script.""" |
| # Copyright (c) 2020-2022, ARM Limited. |
| # SPDX-License-Identifier: Apache-2.0 |
| import argparse |
| import os |
| from enum import Enum |
| from enum import IntEnum |
| from enum import unique |
| from pathlib import Path |
| |
| import numpy as np |
| |
| ################################## |
# Module-wide switch read by print_color(); change it via set_print_in_color().
color_printing = True
| |
| |
@unique
class LogColors(Enum):
    """ANSI escape sequences used to color console status lines.

    ``NONE`` resets terminal attributes and is emitted after every colored
    message; the other members select a bold color (SGR ``1``).
    """

    NONE = "\x1b[0m"
    GREEN = "\x1b[32;1m"
    RED = "\x1b[31;1m"
    YELLOW = "\x1b[33;1m"
    BOLD_WHITE = "\x1b[1m"
| |
| |
def set_print_in_color(enabled):
    """Set color printing to enabled or disabled.

    Args:
        enabled: new value for the module-level ``color_printing`` flag;
            print_color() wraps messages in escape codes only while this
            flag is truthy.
    """
    global color_printing
    color_printing = enabled
| |
| |
def print_color(color, msg):
    """Print a status line, wrapped in ``color`` when coloring is enabled.

    Args:
        color: a LogColors member whose escape sequence is placed before msg.
        msg: the text to print.
    """
    if color_printing:
        # Append the reset sequence so the color does not bleed into later output.
        msg = "{}{}{}".format(color.value, msg, LogColors.NONE.value)
    print(msg)
| |
| |
@unique
class TestResult(IntEnum):
    """Test result values.

    As an IntEnum, a member can be returned directly as a process exit
    status (see the note on PASS).
    """

    # Note: PASS must be 0 for command line return success
    PASS = 0
    # A reference or result file does not exist.
    MISSING_FILE = 1
    # A file could not be loaded with numpy.load.
    INCORRECT_FORMAT = 2
    # dtype, shape or value comparison failed; also used for unsupported dtypes.
    MISMATCH = 3
    # Not returned by the checks in this module; presumably reserved for
    # failures of the checker itself.
    INTERNAL_ERROR = 4
| |
| |
# Human-readable description of each TestResult. Entry i lines up with the
# member whose value is i, so presumably indexed as TestResultErrorStr[result].
TestResultErrorStr = [
    "",  # PASS: no error text
    "Missing file",  # MISSING_FILE
    "Incorrect format",  # INCORRECT_FORMAT
    "Mismatch",  # MISMATCH
    "Internal error",  # INTERNAL_ERROR
]
| ################################## |
| |
| |
def _abs_int_difference(a, b):
    """Return the element-wise absolute difference of two integer arrays.

    The result is computed in uint64 so the subtraction cannot overflow, which
    a plain ``a - b`` can at the int32/int64 extremes (e.g. INT32_MAX minus
    INT32_MIN wraps around and looks like a small difference).
    """
    a = np.asarray(a).astype(np.int64)
    b = np.asarray(b).astype(np.int64)
    high = np.maximum(a, b).astype(np.uint64)
    low = np.minimum(a, b).astype(np.uint64)
    # high >= low, so the true difference lies in [0, 2**64 - 1] and the
    # modulo-2**64 unsigned subtraction yields it exactly.
    with np.errstate(over="ignore"):
        return np.subtract(high, low)


def test_check(
    reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3
):
    """Check if the result is the same as the expected reference.

    Both files are loaded with ``numpy.load``. Their dtypes must be equal and
    their shapes must be equal after squeezing out size-1 dimensions. Values
    are then compared according to dtype:

    * int32/int64: every element must be within ``quantize_tolerance``.
    * bool: every element must match exactly.
    * float32: ``np.allclose`` with ``atol=float_tolerance`` (and numpy's
      default ``rtol``); NaNs in the same positions compare equal.

    Any other dtype is reported as an unsupported-type mismatch.

    Args:
        reference: path to the expected (reference) numpy file.
        result: path to the numpy file produced by the test.
        test_name: name shown in the colored status line.
        quantize_tolerance: maximum absolute difference allowed for integer
            results.
        float_tolerance: absolute tolerance used for float32 results.

    Returns:
        A tuple ``(TestResult, tolerance, msg)``; ``msg`` is empty on a pass.
        ``tolerance`` is ``float_tolerance`` for float32 comparisons, the
        largest absolute difference found for integer mismatches, and 0.0
        otherwise.
    """
    if not os.path.isfile(reference):
        print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
        msg = "Missing reference file: {}".format(reference)
        return (TestResult.MISSING_FILE, 0.0, msg)
    if not os.path.isfile(result):
        print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
        msg = "Missing result file: {}".format(result)
        return (TestResult.MISSING_FILE, 0.0, msg)

    try:
        test_result = np.load(result)
    except Exception as e:
        print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(result, e)
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)
    try:
        reference_result = np.load(reference)
    except Exception as e:
        print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
            reference, e
        )
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # Size = 1 tensors can be equivalently represented as having rank 0 or rank
    # >= 1; allow that by squeezing out every size-1 dimension before comparing.
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)

    if np.shape(test_result) != np.shape(reference_result):
        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        # The reference shape comes first, matching the message wording.
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(reference_result), np.shape(test_result)
        )
        return (TestResult.MISMATCH, 0.0, msg)

    if reference_result.dtype in (np.int32, np.int64):
        # For quantized tests, allow +-(quantize_tolerance) error.
        abs_diff = _abs_int_difference(reference_result, test_result)
        if np.all(abs_diff <= quantize_tolerance):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        # The largest difference is the smallest tolerance that would pass.
        # abs_diff is non-empty here: np.all() of an empty array is True.
        tolerance = int(np.max(abs_diff))
        if tolerance > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                tolerance
            )
        # Fall-through to below to add failure values

    elif reference_result.dtype == bool:
        # dtypes were checked equal above, so both arrays are boolean and
        # every value must match exactly.
        if np.array_equal(reference_result, test_result):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        # Fall-through to below to add failure values

    elif reference_result.dtype == np.float32:
        tolerance = float_tolerance
        # equal_nan=True: NaNs in matching positions count as equal.
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        # Fall-through to below to add failure values

    else:
        print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
    # Scoped print options: summarize large arrays in the message without
    # altering numpy's global print settings for the rest of the process.
    with np.printoptions(threshold=128):
        msg = "{}\ntest_result: {}\n{}".format(msg, test_result.shape, test_result)
        msg = "{}\nreference_result: {}\n{}".format(
            msg, reference_result.shape, reference_result
        )
    return (TestResult.MISMATCH, tolerance, msg)
| |
| |
def main(argv=None):
    """Compare one reference file against one result file from the command line.

    Args:
        argv: argument list to parse; ``None`` lets argparse read the real
            command line.

    Returns:
        The TestResult of the comparison. PASS is 0, so the value can be used
        directly as the process exit status. The failure message is printed
        for any other result.
    """
    cli = argparse.ArgumentParser()
    positional_args = (
        ("reference_path", "the path to the reference file to test"),
        ("result_path", "the path to the result file to test"),
    )
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, type=Path, help=arg_help)
    options = cli.parse_args(argv)

    outcome, _tolerance, message = test_check(
        options.reference_path, options.result_path
    )
    if outcome != TestResult.PASS:
        print(message)

    return outcome
| |
| |
if __name__ == "__main__":
    # sys.exit rather than the `exit` builtin: `exit` is injected by the
    # `site` module and is missing under `python -S` or in frozen builds.
    import sys

    sys.exit(main())