+"""TOSA result checker script."""
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import argparse
+import os
+from enum import Enum
+from enum import IntEnum
+from enum import unique
+from pathlib import Path
+import numpy as np
# Module-wide switch read by print_color(): when True, messages are printed
# as plain text without any ANSI escape sequences.
no_color_printing: bool = False
class LogColors(Enum):
    """ANSI shell escape sequences used to colorize log output.

    Each member's ``value`` is the raw escape string that print_color()
    emits before a message; ``NONE`` resets all attributes and is emitted
    after every colored message.
    """

    NONE = "\u001b[0m"  # reset all attributes
    GREEN = "\u001b[32;1m"  # green foreground, bold
    RED = "\u001b[31;1m"  # red foreground, bold
    YELLOW = "\u001b[33;1m"  # yellow foreground, bold
    BOLD_WHITE = "\u001b[1m"  # bold only, default foreground
def print_color(color, msg):
    """Print a status message, wrapped in color escape codes if enabled.

    Args:
        color: a LogColors member whose ``value`` is the escape prefix.
        msg: the message to print.

    When the module-level ``no_color_printing`` flag is set, *msg* is printed
    unchanged; otherwise it is surrounded by ``color.value`` and the
    ``LogColors.NONE`` reset sequence.
    """
    if not no_color_printing:
        msg = "{}{}{}".format(color.value, msg, LogColors.NONE.value)
    print(msg)
@unique
class TestResult(IntEnum):
    """Test result values returned by test_check() and used as the exit code.

    The integer value of each member indexes its description in
    TestResultErrorStr.
    """

    # Note: PASS must be 0 for command line return success
    PASS = 0
    MISSING_FILE = 1
    INCORRECT_FORMAT = 2
    MISMATCH = 3
    INTERNAL_ERROR = 4
# Human-readable text for each test result; entry i describes the TestResult
# whose integer value is i (PASS == 0 has no message).
TestResultErrorStr = [
    "",
    "Missing file",
    "Incorrect format",
    "Mismatch",
    "Internal error",
]
def test_check(
    reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3
):
    """Check if the result is the same as the expected reference.

    Both files are loaded with ``numpy.load`` and must have the same dtype.
    After squeezing away size-1 dimensions they must have the same shape.
    Values are then compared per dtype:

    * int32 / int64: every element may differ by at most
      ``quantize_tolerance``.
    * bool: all elements must match exactly.
    * float32: ``numpy.allclose`` with ``atol=float_tolerance`` (plus numpy's
      default relative tolerance); NaNs compare equal.
    * any other dtype is reported as a mismatch.

    Args:
        reference: path to the reference numpy file.
        result: path to the result numpy file to verify.
        test_name: name shown in the colored status line.
        quantize_tolerance: maximum allowed absolute difference for integer
            results.
        float_tolerance: absolute tolerance for float32 results.

    Returns:
        A ``(TestResult, tolerance, message)`` tuple. ``message`` is empty on
        success. On failure it describes the problem; for value mismatches it
        also includes both arrays.
    """
    if not os.path.isfile(reference):
        print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
        msg = "Missing reference file: {}".format(reference)
        return (TestResult.MISSING_FILE, 0.0, msg)
    if not os.path.isfile(result):
        print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
        msg = "Missing result file: {}".format(result)
        return (TestResult.MISSING_FILE, 0.0, msg)

    # np.load can fail in several ways on a malformed file (OSError,
    # ValueError, EOFError, ...). Any of them means the file is not a usable
    # array, so report it instead of crashing.
    try:
        test_result = np.load(result)
    except Exception as e:
        print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(result, e)
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)
    try:
        reference_result = np.load(reference)
    except Exception as e:
        print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
            reference, e
        )
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # Size = 1 tensors can be equivalently represented as having rank 0 or
    # rank >= 1, so squeeze before comparing. Squeezing removes every size-1
    # dimension, which means shapes such as (1, 3) and (3,) also compare equal.
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)
    if np.shape(test_result) != np.shape(reference_result):
        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(reference_result), np.shape(test_result)
        )
        return (TestResult.MISMATCH, 0.0, msg)

    if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
        # For quantized tests, allow +-(quantize_tolerance) error.
        # Widen to int64 first so int32 differences cannot wrap around.
        diff = np.absolute(
            reference_result.astype(np.int64) - test_result.astype(np.int64)
        )
        if np.all(diff <= quantize_tolerance):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        # Smallest tolerance above quantize_tolerance that would have passed.
        tolerance = max(quantize_tolerance + 1, int(np.max(diff)))
        if tolerance > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                tolerance
            )
        # Fall-through to below to add failure values
    elif reference_result.dtype == bool:
        # All boolean values must match exactly (dtypes are already known to
        # be equal from the type comparison above).
        if np.array_equal(reference_result, test_result):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        # Fall-through to below to add failure values
    elif reference_result.dtype == np.float32:
        tolerance = float_tolerance
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        # Fall-through to below to add failure values
    else:
        print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
    # Summarize large arrays in the message. The options are scoped to this
    # block so the process-wide numpy print settings are left untouched.
    with np.printoptions(threshold=128):
        msg = "{}\ntest_result: {}\n{}".format(msg, test_result.shape, test_result)
        msg = "{}\nreference_result: {}\n{}".format(
            msg, reference_result.shape, reference_result
        )
    return (TestResult.MISMATCH, tolerance, msg)
def main(argv=None):
    """Compare a reference file against a result file from the command line.

    Args:
        argv: argument list to parse; ``None`` lets argparse read the
            process command line.

    Returns:
        The TestResult of the comparison, usable as a process exit code
        (PASS is 0). For any other result the failure message is printed
        first.
    """
    parser = argparse.ArgumentParser()
    positional_args = (
        ("reference_path", "the path to the reference file to test"),
        ("result_path", "the path to the result file to test"),
    )
    for arg_name, help_text in positional_args:
        parser.add_argument(arg_name, type=Path, help=help_text)
    options = parser.parse_args(argv)

    status, _, message = test_check(options.reference_path, options.result_path)
    if status != TestResult.PASS:
        print(message)
    return status
if __name__ == "__main__":
    # Use SystemExit rather than the site-provided ``exit`` helper. That helper
    # is meant for the interactive interpreter and is absent under
    # ``python -S``. The TestResult returned by main() becomes the exit status.
    raise SystemExit(main())