Initial lazy data-gen and compliance test build support

Add initial support for compliance and lazy data-gen meta data
in desc.json for MATMUL.
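
For illustration, a rough sketch of the extra "meta" block this change
writes into desc.json (field names follow the generator changes below;
the tensor name and values shown are hypothetical):

    meta = {
        "data_gen": {},  # per-tensor data generation set-up (lazy data-gen only)
        "compliance": {
            "version": "0.1",
            "tensors": {
                "result-0": {  # hypothetical output tensor name
                    "mode": "DOT_PRODUCT",
                    "dot_product_info": {"s": 0, "ks": 32, "data_type": "fp32"},
                },
            },
        },
    }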

Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I00c047814134a96d7c98d890e93b5884e25b8e64
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 3014c81..d15f785 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1,8 +1,12 @@
 # Copyright (c) 2020-2023, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
+import json
 import os
 from copy import deepcopy
+from datetime import datetime
+from pathlib import Path
 
+import generator.tosa_utils as gtu
 import numpy as np
 import serializer.tosa_serializer as ts
 from generator.tosa_arg_gen import TosaArgGen
@@ -13,15 +17,15 @@
 from generator.tosa_error_if import TosaErrorIfArgGen
 from generator.tosa_error_if import TosaErrorValidator
 from generator.tosa_error_if import TosaInvalidValidator
-from generator.tosa_utils import DTYPE_ATTRIBUTES
-from generator.tosa_utils import get_rank_mismatch_shape
-from generator.tosa_utils import get_wrong_output_type
-from generator.tosa_utils import MAX_RESIZE_DIMENSION
-from generator.tosa_utils import usableDTypes
-from generator.tosa_utils import vect_f32_to_bf16
+from schemavalidation.schemavalidation import TestDescSchemaValidator
 from tosa.DType import DType
 from tosa.Op import Op
 
+TOSA_AUTOGENERATED_HEADER = f"""// Copyright (c) {datetime.today().year}, ARM Limited
+// SPDX-License-Identifier: Apache-2.0
+// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests
+"""
+
 
 class TosaTestGen:
     # Maximum rank of tensor supported by test generator.
@@ -31,6 +35,10 @@
     TOSA_8K_LEVEL_MAX_KERNEL = 8192
     TOSA_8K_LEVEL_MAX_STRIDE = 8192
 
+    # Main compliance dot product statistical test range
+    TOSA_MI_DOT_PRODUCT_TEST_SETS = range(0, 6)
+    TOSA_MI_DOT_PRODUCT_MIN = 1000
+
     def __init__(self, args):
         self.args = args
         self.basePath = args.output_dir
@@ -45,6 +53,8 @@
         # Work out floating point range
         self.random_fp_low = min(args.tensor_fp_value_range)
         self.random_fp_high = max(args.tensor_fp_value_range)
+        # JSON schema validation
+        self.descSchemaValidator = TestDescSchemaValidator()
 
     def createSerializer(self, opName, testPath):
         self.testPath = os.path.join(opName, testPath)
@@ -53,81 +63,131 @@
         os.makedirs(fullPath, exist_ok=True)
         # Embed const data in the flatbuffer
         constMode = ts.ConstMode.EMBED
-        if self.args.dump_consts:
+        if self.args.lazy_data_gen:
+            # Lazy data generation - so write the constants to separate files
+            constMode = ts.ConstMode.INPUTS
+        elif self.args.dump_consts:
             constMode = ts.ConstMode.EMBED_DUMP
         self.ser = ts.TosaSerializer(fullPath, constMode)
 
     def getSerializer(self):
         return self.ser
 
-    def serialize(self, testName):
-        with open(
-            os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
-        ) as fd:
+    def serialize(self, testName, metaData=None):
+        path = Path(self.basePath) / self.testPath
+
+        # Write out TOSA flatbuffer binary
+        path_fb = path / f"{testName}.tosa"
+        with path_fb.open("wb") as fd:
             fd.write(self.ser.serialize())
 
-        with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
-            fd.write(self.ser.writeJson("{}.tosa".format(testName)))
+        # Get JSON descriptor from serializer
+        desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
+
+        if metaData:
+            # Add extra meta data to desc.json
+            desc["meta"] = metaData
+
+        # Validate desc.json before we output it
+        self.descSchemaValidator.validate_config(desc)
+
+        if metaData:
+            if self.args.lazy_data_gen and "data_gen" in metaData:
+                # Output datagen meta data as CPP data
+                path_md = path / f"{testName}_meta_data_gen.cpp"
+                with path_md.open("w") as fd:
+                    fd.write(TOSA_AUTOGENERATED_HEADER)
+                    fd.write("// Test meta data for data generation setup\n\n")
+                    fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
+                    json.dump(metaData["data_gen"], fd)
+                    fd.write(')";\n\n')
+            if "compliance" in metaData:
+                # Output compliance meta data as CPP data
+                path_md = path / f"{testName}_meta_compliance.cpp"
+                with path_md.open("w") as fd:
+                    fd.write(TOSA_AUTOGENERATED_HEADER)
+                    fd.write("// Test meta data for compliance validation\n\n")
+                    fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
+                    json.dump(metaData["compliance"], fd)
+                    fd.write(')";\n\n')
+
+        # Write desc.json
+        path_desc = path / "desc.json"
+        with path_desc.open("w") as fd:
+            json.dump(desc, fd, indent=1)
 
     def resetRNG(self, seed=None):
         if seed is None:
             seed = self.random_seed + 1
         self.rng = np.random.default_rng(seed)
 
+    def getDTypeRange(self, dtype, high_inclusive=False):
+        # Returns dtype value range boundaries (low, high)
+        # The high boundary is excluded in the range
+        # unless high_inclusive is True
+
+        if dtype in (DType.FP32, DType.FP16, DType.BF16):
+            return (self.random_fp_low, self.random_fp_high)
+        elif dtype == DType.BOOL:
+            rng = (0, 2)
+        elif dtype == DType.UINT8:
+            rng = (0, 256)
+        elif dtype == DType.UINT16:
+            rng = (0, 65536)
+        elif dtype == DType.INT4:
+            # TOSA specific INT4 weight range from -7 to 7
+            rng = (-7, 8)
+        elif dtype == DType.INT8:
+            rng = (-128, 128)
+        elif dtype == DType.INT16:
+            rng = (-32768, 32768)
+        elif dtype in (DType.INT32, DType.SHAPE):
+            # restricting overly large values for SHAPE
+            rng = (-(1 << 31), (1 << 31))
+        elif dtype == DType.INT48:
+            rng = (-(1 << 47), (1 << 47))
+        else:
+            raise Exception("Unknown dtype: {}".format(dtype))
+
+        if not high_inclusive:
+            # Exclusive high: low <= range < high
+            return rng
+        else:
+            # Inclusive range: low <= range <= high
+            return (rng[0], rng[1] - 1)
+
     def getRandTensor(self, shape, dtype):
+        low, high = self.getDTypeRange(dtype)
+
         if dtype == DType.BOOL:
             return np.bool_(self.rng.choice(a=[False, True], size=shape))
-        # TOSA specific INT4 weight range from -7 to 7
-        elif dtype == DType.INT4:
-            return np.int32(self.rng.integers(low=-7, high=8, size=shape))
-        elif dtype == DType.INT8:
-            return np.int32(self.rng.integers(low=-128, high=128, size=shape))
-        elif dtype == DType.UINT8:
-            return np.int32(self.rng.integers(low=0, high=256, size=shape))
-        elif dtype == DType.INT16:
-            return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
-        elif dtype == DType.UINT16:
-            return np.int32(self.rng.integers(low=0, high=65536, size=shape))
-        elif (
-            dtype == DType.INT32 or dtype == DType.SHAPE
-        ):  # restricting too large value for SHAPE
-            return np.int32(
-                self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
-            )
         elif dtype == DType.INT48:
-            return np.int64(
-                self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
-            )
-        elif dtype == DType.FP16:
-            return np.float16(
-                self.rng.uniform(
-                    low=self.random_fp_low, high=self.random_fp_high, size=shape
-                )
-            )
-        elif dtype == DType.BF16:
-            f32_tensor = np.float32(
-                self.rng.uniform(
-                    low=self.random_fp_low, high=self.random_fp_high, size=shape
-                )
-            )
-            # Floor the last 16 bits of each f32 value
-            return np.float32(vect_f32_to_bf16(f32_tensor))
-        elif dtype == DType.FP32:
-            return np.float32(
-                self.rng.uniform(
-                    low=self.random_fp_low, high=self.random_fp_high, size=shape
-                )
-            )
+            return np.int64(self.rng.integers(low=low, high=high, size=shape))
+        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
+            f_tensor = self.rng.uniform(low=low, high=high, size=shape)
+
+            if dtype == DType.FP16:
+                return np.float16(f_tensor)
+            else:
+                f32_tensor = np.float32(f_tensor)
+                if dtype == DType.BF16:
+                    # Floor the last 16 bits of each f32 value
+                    return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
+                else:
+                    return f32_tensor
         else:
-            raise Exception("Unrecognized Dtype: {}".format(dtype))
+            # All other integer types
+            return np.int32(self.rng.integers(low=low, high=high, size=shape))
 
     def buildPlaceholderTensors(self, shape_list, dtype_list):
         placeholders = []
 
         assert len(shape_list) == len(dtype_list)
 
+        arr = None
         for idx, shape in enumerate(shape_list):
-            arr = self.getRandTensor(shape, dtype_list[idx])
+            if not self.args.lazy_data_gen:
+                arr = self.getRandTensor(shape, dtype_list[idx])
             placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
 
         return placeholders
@@ -137,8 +197,10 @@
 
         assert len(shape_list) == len(dtype_list)
 
+        arr = None
         for idx, shape in enumerate(shape_list):
-            arr = self.getRandTensor(shape, dtype_list[idx])
+            if not self.args.lazy_data_gen:
+                arr = self.getRandTensor(shape, dtype_list[idx])
             consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
 
         return consts
@@ -161,38 +223,20 @@
         return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
 
     def getRandNumberDType(self, dtype):
+        low, high = self.getDTypeRange(dtype)
+
         if dtype == DType.FP32:
-            return np.float32(
-                self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
-            )
+            return np.float32(self.rng.uniform(low=low, high=high))
         elif dtype == DType.FP16:
-            return np.float16(
-                self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
-            )
+            return np.float16(self.rng.uniform(low=low, high=high))
         elif dtype == DType.BF16:
-            rand_f32 = np.float32(
-                self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
-            )
-            return vect_f32_to_bf16(rand_f32)
+            rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+            return gtu.vect_f32_to_bf16(rand_f32)
         elif dtype == DType.BOOL:
             return self.rng.choice([False, True])
-        # TOSA specific INT4 weight range from -7 to 7
-        elif dtype == DType.INT4:
-            low, high = (-7, 8)
-        elif dtype == DType.INT8:
-            low, high = (-128, 128)
-        elif dtype == DType.INT16:
-            low, high = (-32768, 32768)
-        elif (
-            dtype == DType.INT32 or dtype == DType.SHAPE
-        ):  # restricting too large value for SHAPE
-            low, high = (-(1 << 31), (1 << 31))
         elif dtype == DType.INT48:
-            low, high = (-(1 << 47), (1 << 47))
             # Special size
             return np.int64(self.rng.integers(low, high, size=1))[0]
-        else:
-            raise Exception("Unknown dtype: {}".format(dtype))
 
         return np.int32(self.rng.integers(low, high, size=1))[0]
 
@@ -212,8 +256,8 @@
             # Limit types to the first 2 as the 3rd is the accumulator
             return "x".join(strs[:2])
         else:
-            if dtype in DTYPE_ATTRIBUTES:
-                return DTYPE_ATTRIBUTES[dtype]["str"]
+            if dtype in gtu.DTYPE_ATTRIBUTES:
+                return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
             else:
                 raise Exception(
                     "Unknown dtype, cannot convert to string: {}".format(dtype)
@@ -221,8 +265,8 @@
 
     def typeWidth(self, dtype):
         """Get the datatype width for data types"""
-        if dtype in DTYPE_ATTRIBUTES:
-            return DTYPE_ATTRIBUTES[dtype]["width"]
+        if dtype in gtu.DTYPE_ATTRIBUTES:
+            return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
         else:
             raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
 
@@ -237,11 +281,44 @@
             low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
         )
 
-    # Argument generators
-    # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
-    # Where the string descriptor is used to generate the test name and
-    # The build_fcn_arg_list is expanded and passed to the operator test
-    # build function
+    def tensorComplianceMetaData(self, op, argsDict, outputTensor, errorName):
+        if errorName:
+            # No compliance for error tests
+            return None
+        # Create compliance meta data for expected output tensor
+        compliance_tens = {"mode": None}
+        if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
+            mode = gtu.ComplianceMode.DOT_PRODUCT
+            compliance_tens["dot_product_info"] = {
+                "s": argsDict["s"],
+                "ks": argsDict["ks"],
+                "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
+            }
+        elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
+            mode = gtu.ComplianceMode.FP_SPECIAL
+        elif "compliance" in op and "ulp" in op["compliance"]:
+            mode = gtu.ComplianceMode.ULP
+            compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
+        elif op["op"] == Op.REDUCE_PRODUCT:
+            mode = gtu.ComplianceMode.REDUCE_PRODUCT
+        else:
+            mode = gtu.ComplianceMode.EXACT
+        compliance_tens["mode"] = gtu.ComplianceMode(mode).name
+
+        return compliance_tens
+
+    # Build Op functions
+    # Create the output tensor (calling OutputShaper as needed)
+    # Do final tweaks to attributes (if necessary for errorIf)
+    # Add Op into graph
+    # Return resulting tensor information or BuildInfo
+
+    class BuildInfo:
+        """Enhanced build information containing result tensor and associated compliance dict."""
+
+        def __init__(self, resultTensor, complianceDict):
+            self.resultTensor = resultTensor
+            self.complianceDict = complianceDict
 
     def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
         result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
@@ -975,15 +1052,16 @@
         return result_tens
 
     def build_matmul(
-        self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
+        self, op, a, b, args_dict, validator_fcns=None, error_name=None, qinfo=None
     ):
-        result_tens = OutputShaper.matmulOp(
+        accum_dtype = args_dict["acc_type"]
+        result_tensor = OutputShaper.matmulOp(
             self.ser, self.rng, a, b, accum_dtype, error_name
         )
 
         # Invalidate Input/Output list for error if checks.
         input_list = [a.name, b.name]
-        output_list = [result_tens.name]
+        output_list = [result_tensor.name]
         pCount, cCount = op["operands"]
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -999,10 +1077,10 @@
             input_dtype=a.dtype,
             input2_shape=b.shape,
             input2_dtype=b.dtype,
-            output_shape=result_tens.shape,
-            output_dtype=result_tens.dtype,
+            output_shape=result_tensor.shape,
+            output_dtype=result_tensor.dtype,
             qinfo=qinfo,
-            result_tensors=[result_tens],
+            result_tensors=[result_tensor],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1014,7 +1092,12 @@
         attr.MatMulAttribute(qinfo[0], qinfo[1])
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
-        return result_tens
+
+        compliance = self.tensorComplianceMetaData(
+            op, args_dict, result_tensor, error_name
+        )
+
+        return TosaTestGen.BuildInfo(result_tensor, compliance)
 
     def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
         result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
@@ -1895,7 +1978,7 @@
 
     def _get_condition_tensor(self, op, cond, error_name):
         if error_name == ErrorIf.CondIfCondNotMatchingBool:
-            cond_type = get_wrong_output_type(op, self.rng, DType.BOOL)
+            cond_type = gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
         else:
             cond_type = DType.BOOL
         if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
@@ -2357,7 +2440,7 @@
         # Initialize a new random number generator
         self.rng = np.random.default_rng(self.random_seed)
 
-        build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
+        _, tgen_fcn, _, agen_fcn = op["build_fcn"]
 
         # Test list consists of a tuple of:
         # (opName, testNameStr, dtype, shapeList, argumentsList)
@@ -2461,7 +2544,7 @@
         # Create a serializer
         self.createSerializer(opName, testStr)
 
-        build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
+        build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
         if "error_if_validators" in op:
             error_if_validators = op["error_if_validators"]
         else:
@@ -2495,24 +2578,37 @@
             qgen = None
 
         # Build the random tensor operands and the test
-        tens = []
 
         if qgen is not None:
             qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
         else:
             qinfo = None
 
-        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
+        # Extra meta data for the desc.json
+        tensMeta = {}
+
+        # Check we are using the new testArgs interface with an argsDict dictionary
+        if len(testArgs) == 1 and isinstance(testArgs[0], dict):
+            argsDict = testArgs[0]
+            assert "dg_type" in argsDict
+            tvgInfo = tvgen_fcn(
+                self, opName, dtypeList, shapeList, argsDict, error_name
+            )
+            if tvgInfo.dataGenDict:
+                tensMeta["data_gen"] = tvgInfo.dataGenDict
+            tens = tvgInfo.tensorList
+        else:
+            tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
 
         try:
             if error_if_validators is None:
                 if qinfo is not None:
-                    resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
+                    result = build_fcn(self, op, *tens, *testArgs, qinfo)
                 else:
-                    resultName = build_fcn(self, op, *tens, *testArgs)
+                    result = build_fcn(self, op, *tens, *testArgs)
             else:
                 if qinfo is not None:
-                    resultName = build_fcn(
+                    result = build_fcn(
                         self,
                         op,
                         *tens,
@@ -2522,7 +2618,7 @@
                         qinfo=qinfo,
                     )
                 else:
-                    resultName = build_fcn(
+                    result = build_fcn(
                         self,
                         op,
                         *tens,
@@ -2534,9 +2630,16 @@
             print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
             raise e
 
-        if resultName:
+        if result:
             # The test is valid, serialize it
-            self.serialize("test")
+            if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
+                # Add the compliance meta data
+                # NOTE: This currently expects only one result output
+                tensMeta["compliance"] = {
+                    "version": "0.1",
+                    "tensors": {result.resultTensor.name: result.complianceDict},
+                }
+            self.serialize("test", tensMeta)
         else:
             # The test is not valid
             print(f"Invalid ERROR_IF test created: {opName} {testStr}")
@@ -2865,7 +2968,7 @@
             "build_fcn": (
                 build_matmul,
                 TosaTensorGen.tgMatmul,
-                TosaTensorValuesGen.tvgDefault,
+                TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agMatMul,
             ),
             "qgen": TosaQuantGen.qgMatmul,
@@ -2878,6 +2981,10 @@
                 TosaErrorValidator.evWrongInputList,
                 TosaErrorValidator.evWrongOutputList,
             ),
+            "data_gen": {
+                "fp": (gtu.DataGenType.DOT_PRODUCT,),
+                "int": (gtu.DataGenType.PSEUDO_RANDOM,),
+            },
         },
         "max_pool2d": {
             "op": Op.MAX_POOL2D,
@@ -4446,7 +4553,7 @@
                 excludes = [DType.FP16, DType.FP32]
             else:
                 excludes = [out_dtype]
-            wrong_dtypes = list(usableDTypes(excludes=excludes))
+            wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
             out_dtype = rng.choice(wrong_dtypes)
 
         return ser.addOutput(ofm_shape, out_dtype)
@@ -4508,7 +4615,7 @@
                 excludes = [DType.FP16, DType.FP32]
             else:
                 excludes = [out_dtype]
-            wrong_dtypes = list(usableDTypes(excludes=excludes))
+            wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
             out_dtype = rng.choice(wrong_dtypes)
 
         return ser.addOutput(ofm_shape, out_dtype)
@@ -4559,7 +4666,7 @@
                 excludes = [DType.FP16, DType.FP32]
             else:
                 excludes = [out_dtype]
-            wrong_dtypes = list(usableDTypes(excludes=excludes))
+            wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
             out_dtype = rng.choice(wrong_dtypes)
 
         return ser.addOutput(ofm_shape, out_dtype)
@@ -4711,7 +4818,7 @@
             bad_dim = rng.choice(range(len(output_shape)))
             output_shape[bad_dim] -= rng.choice([1, 2])
         elif error_name == ErrorIf.RankMismatch:
-            output_shape = get_rank_mismatch_shape(rng, output_shape)
+            output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
 
         if error_name == ErrorIf.WrongOutputType:
             all_dtypes = [
@@ -4806,7 +4913,7 @@
         elif error_name == ErrorIf.InputSizeStartLengthMismatch:
             output_shape = input.shape.copy()
         elif error_name == ErrorIf.RankMismatch:
-            output_shape = get_rank_mismatch_shape(rng, output_shape)
+            output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
 
         return ser.addOutput(output_shape, outputDType)
 
@@ -4820,7 +4927,7 @@
             output_shape[i] = a.shape[i] * multiples[i]
 
         if error_name == ErrorIf.RankMismatch:
-            output_shape = get_rank_mismatch_shape(rng, output_shape)
+            output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
 
         if error_name == ErrorIf.WrongOutputType:
             all_dtypes = [
@@ -4853,7 +4960,7 @@
             for i in range(len(output_shape)):
                 output_shape[i] += rng.integers(1, 10)
         elif error_name == ErrorIf.RankMismatch:
-            output_shape = get_rank_mismatch_shape(rng, output_shape)
+            output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
 
         if error_name == ErrorIf.WrongOutputType:
             all_dtypes = [
@@ -4980,21 +5087,21 @@
             oh = max(oh, 1)
             ow = max(ow, 1)
             if error_name != ErrorIf.MaxDimExceeded:
-                oh = min(oh, MAX_RESIZE_DIMENSION - 1)
-                ow = min(ow, MAX_RESIZE_DIMENSION - 1)
+                oh = min(oh, gtu.MAX_RESIZE_DIMENSION - 1)
+                ow = min(ow, gtu.MAX_RESIZE_DIMENSION - 1)
 
         if error_name == ErrorIf.ResizeOutputShapeMismatch:
             choices = [1, 2, 3]
             change = rng.choice(choices)
             # increment in multiples of scale_y/x_d so we don't hit non-integer error case
             if change in [1, 3]:
-                if oh + scale_y_d >= MAX_RESIZE_DIMENSION:
+                if oh + scale_y_d >= gtu.MAX_RESIZE_DIMENSION:
                     oh -= scale_y_d
                     assert oh > 0  # Should have been caught in agResize
                 else:
                     oh += scale_y_d
             if change in [2, 3]:
-                if ow + scale_x_d >= MAX_RESIZE_DIMENSION:
+                if ow + scale_x_d >= gtu.MAX_RESIZE_DIMENSION:
                     ow -= scale_x_d
                     assert ow > 0  # Should have been caught in agResize
                 else:
@@ -5051,7 +5158,7 @@
                 excludes = [DType.FP16, DType.FP32]
             else:
                 excludes = [out_dtype]
-            wrong_dtypes = list(usableDTypes(excludes=excludes))
+            wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
             out_dtype = rng.choice(wrong_dtypes)
 
         return ser.addOutput(output_shape, out_dtype)
@@ -5075,7 +5182,7 @@
 
         if error_name == ErrorIf.WrongOutputType:
             excludes = [DType.FP32]
-            wrong_dtypes = list(usableDTypes(excludes=excludes))
+            wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
             output_dtype = rng.choice(wrong_dtypes)
         elif error_name == ErrorIf.BatchMismatch:
             output_shape[0] += rng.integers(1, 10)
@@ -5100,7 +5207,7 @@
         output_dtype = value.dtype
         if error_name == ErrorIf.WrongOutputType:
             excludes = [DType.FP32]
-            wrong_dtypes = list(usableDTypes(excludes=excludes))
+            wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
             output_dtype = rng.choice(wrong_dtypes)
         elif error_name == ErrorIf.BatchMismatch:
             output_shape[0] += rng.integers(1, 10)