Support for compliance checking testing

Updated the conformance generator so that compliance tests are generated
without expected results.
Updated the test runner to run the compliance mode version (precise & abs
mode) of the reference model to create results to check against SUT results.
Updated the reference model to enable abs_mode when the relevant desc.json flags are set.
Updated test checker to support compliance checking using verifier lib.
Separated color printing from the test checker.

Change-Id: I7e2fbfc6883916caa5d94d4ece122c48bf45f530
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 1169a95..38ed510 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -1,43 +1,19 @@
 """TOSA result checker script."""
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
 import argparse
-from enum import Enum
+import json
 from enum import IntEnum
 from enum import unique
 from pathlib import Path
 
 import numpy as np
+from checker.color_print import LogColors
+from checker.color_print import print_color
+from checker.verifier import VerifierError
+from checker.verifier import VerifierLibrary
 from generator.tosa_utils import float32_is_valid_bfloat16
-
-##################################
-color_printing = True
-
-
-@unique
-class LogColors(Enum):
-    """Shell escape sequence colors for logging."""
-
-    NONE = "\u001b[0m"
-    GREEN = "\u001b[32;1m"
-    RED = "\u001b[31;1m"
-    YELLOW = "\u001b[33;1m"
-    BOLD_WHITE = "\u001b[1m"
-
-
-def set_print_in_color(enabled):
-    """Set color printing to enabled or disabled."""
-    global color_printing
-    color_printing = enabled
-
-
-def print_color(color, msg):
-    """Print color status messages if enabled."""
-    global color_printing
-    if not color_printing:
-        print(msg)
-    else:
-        print("{}{}{}".format(color.value, msg, LogColors.NONE.value))
+from schemavalidation.schemavalidation import TestDescSchemaValidator
 
 
 @unique
@@ -62,46 +38,120 @@
 ##################################
 
 DEFAULT_FP_TOLERANCE = 1e-3
+result_printing = True
+
+
+def set_print_result(enabled):
+    """Enable or disable printing of result messages by _print_result()."""
+    global result_printing
+    result_printing = enabled
+
+
+def _print_result(color, msg):
+    """Emit msg via print_color() in the given color, unless printing is off."""
+    if not result_printing:
+        return
+    print_color(color, msg)
+
+
+def compliance_check(
+    imp_result_path,
+    ref_result_path,
+    bnd_result_path,
+    test_name,
+    compliance_config,
+    ofm_name,
+    verify_lib_path,
+):
+    """Check results with the verifier library; return (TestResult, 0.0, message)."""
+    try:
+        vlib = VerifierLibrary(verify_lib_path)
+    except VerifierError as e:
+        _print_result(LogColors.RED, f"INTERNAL ERROR {test_name}")
+        msg = f"Could not load verifier library: {e}"
+        return (TestResult.INTERNAL_ERROR, 0.0, msg)
+
+    success = vlib.verify_data(
+        ofm_name, compliance_config, imp_result_path, ref_result_path, bnd_result_path
+    )
+    if success:
+        _print_result(LogColors.GREEN, f"Results PASS {test_name}")
+        return (TestResult.PASS, 0.0, "")
+    _print_result(LogColors.RED, f"Results NON-COMPLIANT {test_name}")
+    return (TestResult.MISMATCH, 0.0, "Non-compliant implementation results found")
 
 
 def test_check(
-    reference_path,
-    result_path,
-    test_name="test",
+    ref_result_path,
+    imp_result_path,
+    test_name=None,
     quantize_tolerance=0,
     float_tolerance=DEFAULT_FP_TOLERANCE,
     misc_checks=[],
+    test_desc=None,
+    bnd_result_path=None,
+    ofm_name=None,
+    verify_lib_path=None,
 ):
     """Check if the result is the same as the expected reference."""
-    if not reference_path.is_file():
-        print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
-        msg = "Missing reference file: {}".format(reference_path)
-        return (TestResult.MISSING_FILE, 0.0, msg)
-    if not result_path.is_file():
-        print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
-        msg = "Missing result file: {}".format(result_path)
-        return (TestResult.MISSING_FILE, 0.0, msg)
+    if test_desc:
+        # New compliance method - first get test details
+        try:
+            TestDescSchemaValidator().validate_config(test_desc)
+        except Exception as e:
+            _print_result(LogColors.RED, f"Test INCORRECT FORMAT {test_name}")
+            msg = f"Incorrect test format: {e}"
+            return (TestResult.INCORRECT_FORMAT, 0.0, msg)
 
-    try:
-        test_result = np.load(result_path)
-    except Exception as e:
-        print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
-        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
-            result_path, e
+    if test_name is None:
+        test_name = "test"
+
+    paths = [imp_result_path, ref_result_path, bnd_result_path]
+    names = ["Implementation", "Reference", "Bounds"]
+    arrays = [None, None, None]
+
+    # Check the files exist and are in the right format
+    for idx, path in enumerate(paths):
+        name = names[idx]
+        if path is None and name == "Bounds":
+            # Bounds can be None - skip it
+            continue
+        if not path.is_file():
+            _print_result(LogColors.RED, f"{name} MISSING FILE {test_name}")
+            msg = f"Missing {name} file: {str(path)}"
+            return (TestResult.MISSING_FILE, 0.0, msg)
+        try:
+            arrays[idx] = np.load(path)
+        except Exception as e:
+            _print_result(LogColors.RED, f"{name} INCORRECT FORMAT {test_name}")
+            msg = f"Incorrect numpy format of {str(path)}\nnumpy.load exception: {e}"
+            return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+
+    if test_desc and "meta" in test_desc and "compliance" in test_desc["meta"]:
+        # Switch to using the verifier library for full compliance
+        if ofm_name is None:
+            ofm_name = test_desc["ofm_name"][0]
+            if len(test_desc["ofm_name"]) > 1:
+                _print_result(LogColors.RED, f"Output Name MISSING FILE {test_name}")
+                msg = "Must specify output name (ofm_name) to check as multiple found in desc.json"
+                return (TestResult.MISSING_FILE, 0.0, msg)
+
+        compliance_json = test_desc["meta"]["compliance"]
+
+        return compliance_check(
+            *arrays,
+            test_name,
+            compliance_json,
+            ofm_name,
+            verify_lib_path,
         )
-        return (TestResult.INCORRECT_FORMAT, 0.0, msg)
-    try:
-        reference_result = np.load(reference_path)
-    except Exception as e:
-        print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
-        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
-            reference_path, e
-        )
-        return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+
+    # Else continue with original checking method
+    test_result, reference_result, _ = arrays
 
     # Type comparison
     if test_result.dtype != reference_result.dtype:
-        print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
+        _print_result(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
         msg = "Mismatch results type: Expected {}, got {}".format(
             reference_result.dtype, test_result.dtype
         )
@@ -115,7 +165,7 @@
     difference = None
 
     if np.shape(test_result) != np.shape(reference_result):
-        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
+        _print_result(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
         msg = "Shapes mismatch: Reference {} vs {}".format(
             np.shape(test_result), np.shape(reference_result)
         )
@@ -139,7 +189,7 @@
     if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
 
         if np.all(np.absolute(reference_result - test_result) <= quantize_tolerance):
-            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
+            _print_result(LogColors.GREEN, "Results PASS {}".format(test_name))
             return (TestResult.PASS, 0.0, "")
         else:
             tolerance = quantize_tolerance + 1
@@ -166,7 +216,7 @@
         # All boolean values must match, xor will show up differences
         test = np.array_equal(reference_result, test_result)
         if np.all(test):
-            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
+            _print_result(LogColors.GREEN, "Results PASS {}".format(test_name))
             return (TestResult.PASS, 0.0, "")
         msg = "Boolean result does not match"
         tolerance = 0.0
@@ -177,18 +227,18 @@
     elif reference_result.dtype == np.float32 or reference_result.dtype == np.float16:
         tolerance = float_tolerance
         if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
-            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
+            _print_result(LogColors.GREEN, "Results PASS {}".format(test_name))
             return (TestResult.PASS, tolerance, "")
         msg = "Float result does not match within tolerance of {}".format(tolerance)
         difference = reference_result - test_result
         # Fall-through to below to add failure values
     else:
-        print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
+        _print_result(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
         msg = "Unsupported results type: {}".format(reference_result.dtype)
         return (TestResult.MISMATCH, 0.0, msg)
 
     # Fall-through for mismatch failure to add values to msg
-    print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
+    _print_result(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
     np.set_printoptions(threshold=128, edgeitems=2)
 
     if difference is not None:
@@ -209,18 +259,65 @@
     """Check that the supplied reference and result files have the same contents."""
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "reference_path", type=Path, help="the path to the reference file to test"
+        "ref_result_path",
+        type=Path,
+        help="path to the reference model result file to check",
     )
     parser.add_argument(
-        "result_path", type=Path, help="the path to the result file to test"
+        "imp_result_path",
+        type=Path,
+        help="path to the implementation result file to check",
     )
     parser.add_argument(
         "--fp-tolerance", type=float, default=DEFAULT_FP_TOLERANCE, help="FP tolerance"
     )
+    parser.add_argument(
+        "--test_path", type=Path, help="path to the test that produced the results"
+    )
+    parser.add_argument(
+        "--bnd-result-path",
+        type=Path,
+        help="path to the reference model bounds result file for the dot product compliance check",
+    )
+    parser.add_argument(
+        "--ofm-name",
+        type=str,
+        help="name of the output tensor to check, defaults to the first ofm_name listed in the test",
+    )
+    parser.add_argument(
+        "--verify-lib-path",
+        type=Path,
+        help="path to TOSA verify library",
+    )
     args = parser.parse_args(argv)
 
+    if args.test_path:
+        # Get details from the test path
+        test_desc_path = args.test_path / "desc.json"
+        if not args.test_path.is_dir() or not test_desc_path.is_file():
+            print(f"Invalid test directory {str(args.test_path)}")
+            return TestResult.MISSING_FILE
+
+        try:
+            with test_desc_path.open("r") as fd:
+                test_desc = json.load(fd)
+        except Exception as e:
+            print(f"Invalid test description file {str(test_desc_path)}: {e}")
+            return TestResult.INCORRECT_FORMAT
+        test_name = args.test_path.name
+    else:
+        test_desc = None
+        test_name = None
+
     result, tolerance, msg = test_check(
-        args.reference_path, args.result_path, float_tolerance=args.fp_tolerance
+        args.ref_result_path,
+        args.imp_result_path,
+        float_tolerance=args.fp_tolerance,
+        test_name=test_name,
+        test_desc=test_desc,
+        bnd_result_path=args.bnd_result_path,
+        ofm_name=args.ofm_name,
+        verify_lib_path=args.verify_lib_path,
     )
     if result != TestResult.PASS:
         print(msg)