blob: 66864c21daf3000b67ad5a2b4d1b4733e103c790 [file] [log] [blame]
"""TOSA result checker script."""
# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
import sys
from enum import Enum
from enum import IntEnum
from enum import unique
from pathlib import Path

import numpy as np
##################################
# Module-level switch for colored output; set via set_print_in_color()
# and read by print_color() below.
color_printing = True
@unique
class LogColors(Enum):
    """Shell escape sequence colors for logging."""

    NONE = "\u001b[0m"  # reset attributes; appended after every colored message
    GREEN = "\u001b[32;1m"  # bright green - used for PASS status lines
    RED = "\u001b[31;1m"  # bright red - used for failure status lines
    YELLOW = "\u001b[33;1m"  # bright yellow - not referenced in this chunk
    BOLD_WHITE = "\u001b[1m"  # bold default color - not referenced in this chunk
def set_print_in_color(enabled):
    """Globally enable or disable ANSI-colored status output.

    Args:
        enabled: truthy to wrap subsequent print_color() output in escape
            codes, falsy to print plain text.
    """
    global color_printing
    color_printing = enabled
def print_color(color, msg):
    """Print a status message, wrapped in *color*'s escape codes if enabled.

    Args:
        color: a LogColors member whose value prefixes the message.
        msg: text to print; a color reset sequence is appended when coloring.
    """
    # Read-only access to the module flag needs no `global` declaration.
    if color_printing:
        print(f"{color.value}{msg}{LogColors.NONE.value}")
    else:
        print(msg)
@unique
class TestResult(IntEnum):
    """Test result values."""

    # Note: PASS must be 0 for command line return success
    PASS = 0
    MISSING_FILE = 1  # reference or result file not found on disk
    INCORRECT_FORMAT = 2  # numpy.load failed to parse a file
    MISMATCH = 3  # dtype, shape, or value mismatch (or unsupported dtype)
    INTERNAL_ERROR = 4  # not raised in this chunk; presumably for callers — TODO confirm
# Human-readable messages indexed by TestResult value (PASS maps to "").
# NOTE: entries must stay in the same order as the TestResult enum above.
TestResultErrorStr = [
    "",
    "Missing file",
    "Incorrect format",
    "Mismatch",
    "Internal error",
]
##################################
def test_check(
    reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3
):
    """Check if the result is the same as the expected reference.

    Args:
        reference: path to the reference .npy file.
        result: path to the result .npy file under test.
        test_name: label used in the printed colored status lines.
        quantize_tolerance: allowed +/- element difference for int32/int64 data.
        float_tolerance: absolute tolerance given to numpy.allclose for float32.

    Returns:
        Tuple (TestResult, tolerance, message). PASS returns an empty message;
        failures return a description of the mismatch.
    """
    if not os.path.isfile(reference):
        print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
        msg = "Missing reference file: {}".format(reference)
        return (TestResult.MISSING_FILE, 0.0, msg)
    if not os.path.isfile(result):
        print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
        msg = "Missing result file: {}".format(result)
        return (TestResult.MISSING_FILE, 0.0, msg)

    try:
        test_result = np.load(result)
    except Exception as e:
        print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(result, e)
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)
    try:
        reference_result = np.load(reference)
    except Exception as e:
        print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
            reference, e
        )
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # Size = 1 tensors can be equivalently represented as having rank 0 or rank
    # >= 0, allow that special case
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)
    if np.shape(test_result) != np.shape(reference_result):
        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(test_result), np.shape(reference_result)
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # for quantized test, allow +-(quantize_tolerance) error
    if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
        # Hoist the element-wise difference; it is invariant below.
        difference = np.absolute(reference_result - test_result)
        if np.all(difference <= quantize_tolerance):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        else:
            # Grow the tolerance until it covers the worst mismatch, capped
            # at 10.  BUG FIX: the loop previously compared against the
            # constant quantize_tolerance, so it never terminated early and
            # always reported the ">10 difference" message.
            tolerance = quantize_tolerance + 1
            while not np.all(difference <= tolerance):
                tolerance = tolerance + 1
                if tolerance > 10:
                    break
            if tolerance > 10:
                msg = "Integer result does not match and is greater than 10 difference"
            else:
                msg = (
                    "Integer result does not match but is within {} difference".format(
                        tolerance
                    )
                )
            # Fall-through to below to add failure values
    elif reference_result.dtype == bool:
        assert test_result.dtype == bool
        # All boolean values must match, xor will show up differences
        test = np.array_equal(reference_result, test_result)
        if np.all(test):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        # Fall-through to below to add failure values
    elif reference_result.dtype == np.float32:
        tolerance = float_tolerance
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        # Fall-through to below to add failure values
    else:
        print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
    np.set_printoptions(threshold=128)
    msg = "{}\ntest_result: {}\n{}".format(msg, test_result.shape, test_result)
    msg = "{}\nreference_result: {}\n{}".format(
        msg, reference_result.shape, reference_result
    )
    return (TestResult.MISMATCH, tolerance, msg)
def main(argv=None):
    """Check that the supplied reference and result files are the same.

    Args:
        argv: optional argument list; defaults to sys.argv when None.

    Returns:
        The TestResult from the comparison (PASS == 0 for shell success).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "reference_path", type=Path, help="the path to the reference file to test"
    )
    parser.add_argument(
        "result_path", type=Path, help="the path to the result file to test"
    )
    args = parser.parse_args(argv)

    status, _, message = test_check(args.reference_path, args.result_path)
    if status != TestResult.PASS:
        print(message)
    return status
if __name__ == "__main__":
    # Use sys.exit() rather than the site-provided exit() builtin, which is
    # intended for interactive use and may be absent (e.g. `python -S`).
    sys.exit(main())