Replace node-level checks ASSERT_MSG_NODE()/FATAL_ERROR_NODE() with REQUIRE() or ERROR_IF()

- Add a return code enum class: {VALID, UNPREDICTABLE, ERROR}
- Runtime errors (e.g. memory allocation failure) either abort immediately or return one of the three return codes.
  Some code paths are rewritten to propagate REQUIRE() checks up to the top level (e.g. apply_scale_32/16()).
- Replace setExpectedFailure() with setExpectedReturnCode() in the test generation script
- Update the test regression script to handle the reference model's new return codes

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ia063c936bcb2a54d6e379a5bb6801aa72d1186f1
diff --git a/verif/tosa_ref_run.py b/verif/tosa_ref_run.py
index 098f39b..499513b 100644
--- a/verif/tosa_ref_run.py
+++ b/verif/tosa_ref_run.py
@@ -1,5 +1,3 @@
-import os
-
 # Copyright (c) 2020-2021, ARM Limited.
 #
 #    Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,9 +16,17 @@
 import json
 import shlex
 import subprocess
+from enum import Enum, IntEnum, unique
 from tosa_test_runner import TosaTestRunner, run_sh_command
 
 
+@unique
+class TosaReturnCode(IntEnum):
+    VALID = 0
+    UNPREDICTABLE = 1
+    ERROR = 2
+
+
 class TosaRefRunner(TosaTestRunner):
     def __init__(self, args, runnerArgs, testDir):
         super().__init__(args, runnerArgs, testDir)
@@ -41,18 +47,29 @@
         if args.ref_intermediates:
             ref_cmd.extend(["-Ddump_intermediates=1"])
 
-        expectedFailure = self.testDesc["expected_failure"]
+        expectedReturnCode = self.testDesc["expected_return_code"]
 
         try:
-            run_sh_command(self.args, ref_cmd)
-            if expectedFailure:
-                result = TosaTestRunner.Result.UNEXPECTED_PASS
+            rc = run_sh_command(self.args, ref_cmd)
+            if rc == TosaReturnCode.VALID:
+                if expectedReturnCode == TosaReturnCode.VALID:
+                    result = TosaTestRunner.Result.EXPECTED_PASS
+                else:
+                    result = TosaTestRunner.Result.UNEXPECTED_PASS
+            elif rc == TosaReturnCode.ERROR:
+                if expectedReturnCode == TosaReturnCode.ERROR:
+                    result = TosaTestRunner.Result.EXPECTED_FAILURE
+                else:
+                    result = TosaTestRunner.Result.UNEXPECTED_FAILURE
+            elif rc == TosaReturnCode.UNPREDICTABLE:
+                if expectedReturnCode == TosaReturnCode.UNPREDICTABLE:
+                    result = TosaTestRunner.Result.EXPECTED_FAILURE
+                else:
+                    result = TosaTestRunner.Result.UNEXPECTED_FAILURE
             else:
-                result = TosaTestRunner.Result.EXPECTED_PASS
+                raise Exception("Return code unknown.")
+
         except Exception as e:
-            if expectedFailure:
-                result = TosaTestRunner.Result.EXPECTED_FAILURE
-            else:
-                result = TosaTestRunner.Result.UNEXPECTED_FAILURE
+            raise Exception("Runtime Error when running: {}".format(" ".join(ref_cmd)))
 
         return result
diff --git a/verif/tosa_serializer.py b/verif/tosa_serializer.py
index b4daaad..35dd9a2 100644
--- a/verif/tosa_serializer.py
+++ b/verif/tosa_serializer.py
@@ -31,6 +31,7 @@
     ResizeMode,
     Version,
 )
+from tosa_ref_run import TosaReturnCode
 
 # Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
 parent_dir = os.path.dirname(os.path.realpath(__file__))
@@ -57,6 +58,7 @@
 
 ByteMask = np.uint64(0xFF)
 
+
 def dtype_str_to_val(name):
 
     for i in range(len(DTypeNames)):
@@ -428,10 +430,12 @@
                     u8_data.extend([b0, b1, b2, b3, b4, b5])
             elif self.dtype == DType.FLOAT:
                 for val in self.data:
-                    b = struct.pack('!f', val)
+                    b = struct.pack("!f", val)
                     u8_data.extend([b[3], b[2], b[1], b[0]])
             else:
-                raise Exception("unsupported data type {}".format(DTypeNames[self.dtype]))
+                raise Exception(
+                    "unsupported data type {}".format(DTypeNames[self.dtype])
+                )
             fb_data = TosaSerializer.serializeUint8Vec(builder, u8_data)
 
         TosaTensor.TosaTensorStart(builder)
@@ -586,7 +590,7 @@
         self.currResultIdx = 0
 
         # Is this an illegal test that is expected to fail?
-        self.expectedFailure = False
+        self.expectedReturnCode = TosaReturnCode.VALID
         self.expectedFailureDesc = ""
 
     def __str__(self):
@@ -665,9 +669,9 @@
             op, inputs, outputs, attributes, quant_info
         )
 
-    def setExpectedFailure(self, desc="", val=True):
+    def setExpectedReturnCode(self, val, desc=""):
 
-        self.expectedFailure = val
+        self.expectedReturnCode = val
         self.expectedFailureDesc = desc
 
     def serialize(self):
@@ -719,7 +723,7 @@
         test_desc["ifm_file"] = ifm_file
         test_desc["ofm_name"] = ofm_name
         test_desc["ofm_file"] = ofm_file
-        test_desc["expected_failure"] = self.expectedFailure
+        test_desc["expected_return_code"] = self.expectedReturnCode
         if self.expectedFailureDesc:
             test_desc["expected_failure_desc"] = self.expectedFailureDesc
 
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index a3c6b05..efc819c 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -32,6 +32,7 @@
 import itertools
 
 from enum import IntEnum, Enum, unique
+from tosa_ref_run import TosaReturnCode
 
 # Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
 parent_dir = os.path.dirname(os.path.realpath(__file__))
@@ -65,8 +66,9 @@
     @staticmethod
     def qgUnary(testGen, op, dtype):
         qinfo = ts.TosaSerializerQuantInfo()
-        qinfo.UnaryQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
-                TosaQuantGen.getQinfo(testGen, dtype))
+        qinfo.UnaryQuantInfo(
+            TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
+        )
         return qinfo
 
     @staticmethod
@@ -86,8 +88,9 @@
     @staticmethod
     def qgMatmul(testGen, op, dtype):
         qinfo = ts.TosaSerializerQuantInfo()
-        qinfo.MatMulQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
-                TosaQuantGen.getQinfo(testGen, dtype))
+        qinfo.MatMulQuantInfo(
+            TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
+        )
         return qinfo
 
     @staticmethod
@@ -304,13 +307,11 @@
         assert rank == 2
 
         input_shape = testGen.makeShape(rank)
-        filter_oc = (
-            testGen.rng.integers(
-                low=testGen.args.tensor_shape_range[0],
-                high=testGen.args.tensor_shape_range[1],
-                size=1,
-            )[0]
-        )
+        filter_oc = testGen.rng.integers(
+            low=testGen.args.tensor_shape_range[0],
+            high=testGen.args.tensor_shape_range[1],
+            size=1,
+        )[0]
         filter_shape = np.asarray([filter_oc, input_shape[1]])
 
         bias_shape = np.asarray([filter_oc])
@@ -734,7 +735,10 @@
         random_permutations = testGen.rng.permutation(permutations)
 
         # Create list of required amount of permutations
-        arg_list = [("perm{}".format(p), [random_permutations[p].tolist()]) for p in range(limit)]
+        arg_list = [
+            ("perm{}".format(p), [random_permutations[p].tolist()])
+            for p in range(limit)
+        ]
         return arg_list
 
     @staticmethod
@@ -1154,7 +1158,7 @@
     def build_table(self, op, a):
         # Constant size depending on type, random values
         if a.dtype == DType.INT16:
-            table_dtype =  DType.INT16
+            table_dtype = DType.INT16
             table_arr = self.getRandTensor([513], table_dtype)
         else:
             assert a.dtype == DType.INT8
@@ -1497,7 +1501,7 @@
         if val.dtype == DType.INT8:
             input_zp = self.randInt(-128, 128)
             in_type_width = in_type_width + 1
-        elif  val.dtype == DType.UINT8:
+        elif val.dtype == DType.UINT8:
             input_zp = self.randInt(0, 256)
             in_type_width = in_type_width + 1
         else:
@@ -1536,7 +1540,9 @@
                 scale_arr[i], scale32
             )
             if shift_arr[i] < 2 or shift_arr[i] > 62:
-                self.ser.setExpectedFailure(True, "OpRescale: invalid shift value")
+                self.ser.setExpectedReturnCode(
+                    TosaReturnCode.UNPREDICTABLE, "OpRescale: invalid shift value"
+                )
 
         # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
 
@@ -1710,14 +1716,21 @@
             # Filter out the rank?
             if rankFilter is not None and r not in rankFilter:
                 continue
-            if rankFilter is None and shapeFilter[0] is None and r not in default_test_rank_range:
+            if (
+                rankFilter is None
+                and shapeFilter[0] is None
+                and r not in default_test_rank_range
+            ):
                 continue
 
             for t in op["types"]:
 
                 # Filter tests based on dtype?
                 if dtypeFilter is not None:
-                    if not (t in dtypeFilter or (isinstance(t, list) and t[0] in dtypeFilter)):
+                    if not (
+                        t in dtypeFilter
+                        or (isinstance(t, list) and t[0] in dtypeFilter)
+                    ):
                         continue
 
                 # Create the placeholder and const tensors
@@ -2660,7 +2673,9 @@
             # Invalid test parameters?
             h = 0
             w = 0
-            ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
+            ser.setExpectedReturnCode(
+                TosaReturnCode.UNPREDICTABLE, "Invalid combination of conv2d parameters"
+            )
 
         ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
 
@@ -2700,7 +2715,9 @@
             # Invalid test parameters?
             h = 0
             w = 0
-            ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
+            ser.setExpectedReturnCode(
+                TosaReturnCode.UNPREDICTABLE, "Invalid combination of conv2d parameters"
+            )
 
         ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
 
@@ -2725,7 +2742,9 @@
             # Invalid test parameters?
             h = 0
             w = 0
-            ser.setExpectedFailure(True, "Invalid combination of pooling parameters")
+            ser.setExpectedReturnCode(
+                TosaReturnCode.UNPREDICTABLE, "Invalid combination of pool2d parameters"
+            )
 
         ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
         return ser.addOutput(ofm_shape, ifm.dtype)
@@ -2889,39 +2908,59 @@
 
         if input_dtype == DType.FLOAT:
             if stride_fp[0] <= 0 or stride_fp[1] <= 0:
-                ser.setExpectedFailure(True, "Negative or zero stride")
+                ser.setExpectedReturnCode(
+                    TosaReturnCode.ERROR, "Negative or zero stride"
+                )
         else:
             if stride[0] <= 0 or stride[1] <= 0:
-                ser.setExpectedFailure(True, "Negative or zero stride")
+                ser.setExpectedReturnCode(
+                    TosaReturnCode.ERROR, "Negative or zero stride"
+                )
 
         if mode == ResizeMode.BILINEAR:
             if input_dtype == DType.INT8:
                 if output_dtype != DType.INT32:
-                    ser.setExpectedFailure(True, "Invalid output data type")
+                    ser.setExpectedReturnCode(
+                        TosaReturnCode.ERROR, "Invalid output data type"
+                    )
             elif input_dtype == DType.INT16:
                 if output_dtype != DType.INT48:
-                    ser.setExpectedFailure(true, "Invalid output data type")
+                    ser.setExpectedReturnCode(
+                        TosaReturnCode.ERROR, "Invalid output data type"
+                    )
             elif input_dtype == DType.FLOAT:
                 if output_dtype != DType.FLOAT:
-                    ser.setExpectedFailure(true, "Invalid output data type")
+                    ser.setExpectedReturnCode(
+                        TosaReturnCode.ERROR, "Invalid output data type"
+                    )
             else:
-                ser.setExpectedFailure(true, "Invalid input data type")
+                ser.setExpectedReturnCode(
+                    TosaReturnCode.ERROR, "Invalid input data type"
+                )
 
         elif mode == ResizeMode.NEAREST:
             if input_dtype == DType.INT8:
                 if output_dtype != DType.INT8:
-                    ser.setExpectedFailure(True, "Invalid output data type")
+                    ser.setExpectedReturnCode(
+                        TosaReturnCode.ERROR, "Invalid output data type"
+                    )
             elif input_dtype == DType.INT16:
                 if output_dtype != DType.INT16:
-                    ser.setExpectedFailure(true, "Invalid output data type")
+                    ser.setExpectedReturnCode(
+                        TosaReturnCode.ERROR, "Invalid output data type"
+                    )
             elif input_dtype == DType.FLOAT:
                 if output_dtype != DType.FLOAT:
-                    ser.setExpectedFailure(true, "Invalid output data type")
+                    ser.setExpectedReturnCode(
+                        TosaReturnCode.ERROR, "Invalid output data type"
+                    )
             else:
-                ser.setExpectedFailure(true, "Invalid input data type")
+                ser.setExpectedReturnCode(
+                    TosaReturnCode.ERROR, "Invalid input data type"
+                )
 
         else:
-            ser.setExpectedFailure(true, "Invalid resize mode")
+            ser.setExpectedReturnCode(TosaReturnCode.ERROR, "Invalid resize mode")
 
         return ser.addOutput(output_dims, output_dtype)
 
@@ -2941,6 +2980,8 @@
             raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
 
         if output_shape[1] <= 0 or output_shape[2] <= 0:
-            ser.setExpectedFailure(True, "Negative output shape")
+            ser.setExpectedReturnCode(
+                TosaReturnCode.UNPREDICTABLE, "Negative output shape"
+            )
 
         return ser.addOutput(output_shape, out_dtype)
diff --git a/verif/tosa_test_runner.py b/verif/tosa_test_runner.py
index 82d447e..e8f921d 100644
--- a/verif/tosa_test_runner.py
+++ b/verif/tosa_test_runner.py
@@ -42,8 +42,8 @@
         return (rc.stdout, rc.stderr)
     else:
         rc = subprocess.run(full_cmd)
-    if rc.returncode != 0:
-        raise Exception("Error running command: {}".format(" ".join(full_cmd_esc)))
+
+    return rc.returncode
 
 
 class TosaTestRunner: