# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import os
import struct
from copy import deepcopy
from datetime import datetime
from pathlib import Path

import generator.tosa_utils as gtu
import numpy as np
import serializer.tosa_serializer as ts
from generator.datagenerator import GenerateLibrary
from generator.tosa_arg_gen import TosaArgGen
from generator.tosa_arg_gen import TosaQuantGen
from generator.tosa_arg_gen import TosaTensorGen
from generator.tosa_arg_gen import TosaTensorValuesGen
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_error_if import TosaErrorValidator
from generator.tosa_error_if import TosaInvalidValidator
from schemavalidation.schemavalidation import TestDescSchemaValidator
from tosa.DType import DType
from tosa.Op import Op

# Banner written at the top of the C++ meta data files produced by
# TosaTestGen.serialize. The copyright year is evaluated once, when this
# module is imported.
TOSA_AUTOGENERATED_HEADER = f"""// Copyright (c) {datetime.today().year}, ARM Limited
// SPDX-License-Identifier: Apache-2.0
// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests
"""

# Configure the root logger with the logging module defaults (done at import
# time) and create the named logger for this tool.
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")


class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    # This currently matches the 8K level defined in the specification.
    TOSA_TENSOR_MAX_RANK = 6
    # Upper limits for scale, kernel and stride values.
    # NOTE(review): presumably these are also the 8K level limits from the
    # specification - confirm against the TOSA spec.
    TOSA_8K_LEVEL_MAX_SCALE = 64
    TOSA_8K_LEVEL_MAX_KERNEL = 8192
    TOSA_8K_LEVEL_MAX_STRIDE = 8192

    # Main compliance dot product statistical test range
    # NOTE(review): looks like the number of test sets and a minimum size
    # for the dot product tests - confirm against the users of these values.
    TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
    TOSA_MI_DOT_PRODUCT_MIN = 1000

    def __init__(self, args):
        """Initialise the generator from the parsed command line arguments.

        Reads output_dir, random_seed, generate_lib_path and
        tensor_fp_value_range from args, seeds the random number generator
        and pre-computes the random value range used for each floating
        point data type.
        """
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # makeShape returns this shape instead of a random one when it is set
        self.targetted_shape = None
        # JSON schema validation of the test descriptor
        self.descSchemaValidator = TestDescSchemaValidator()
        # Data generator library is sometimes needed for compliance set up
        # even if we are generating the data later (lazy_data_generation)
        self.dgl = GenerateLibrary(args.generate_lib_path)

        def convertFPRange(rangeFP, maxFP):
            # Maps the program arguments onto values no larger in magnitude
            # than maxFP and returns them as a sorted tuple
            def limit(value):
                if value == "max":
                    return maxFP
                if value == "-max":
                    return -maxFP
                if value < 0:
                    # Trim to minimum data type value
                    return max(value, -maxFP)
                if value > 0:
                    # Trim to maximum data type value
                    return min(value, maxFP)
                return value

            return tuple(sorted(limit(v) for v in rangeFP))

        # Work out floating point range for every supported FP data type
        fp_dtypes = (
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
        self.random_float_range = {
            dtype: convertFPRange(
                args.tensor_fp_value_range,
                TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
            )
            for dtype in fp_dtypes
        }

    def createSerializer(self, opName, testPath):
        """Create the output directory and a fresh serializer for one test.

        The test is placed in <basePath>/<opName>/<testPath>; the relative
        part is remembered in self.testPath and the new serializer in
        self.ser.
        """
        self.testPath = os.path.join(opName, testPath)
        out_dir = os.path.join(self.basePath, self.testPath)
        os.makedirs(out_dir, exist_ok=True)

        if self.args.lazy_data_gen:
            # Lazy data generation - so make constants files
            const_mode = ts.ConstMode.INPUTS
        elif self.args.dump_consts:
            const_mode = ts.ConstMode.EMBED_DUMP
        else:
            # Embed const data in the flatbuffer
            const_mode = ts.ConstMode.EMBED

        self.ser = ts.TosaSerializer(out_dir, const_mode)

    def getSerializer(self):
        return self.ser

    def serialize(self, testName, metaData=None):
        """Write the files describing the current test to its directory.

        Outputs, in order: the <testName>.tosa flatbuffer, optional C++
        files holding the data generation / compliance meta data as raw
        string constants, and the desc.json descriptor. The descriptor is
        schema validated before any meta data or desc.json is written.

        The C++ constants are named after the last component of the test
        directory.
        """
        test_dir = Path(self.basePath) / self.testPath
        tosa_file = f"{testName}.tosa"

        # Write out TOSA flatbuffer binary
        with (test_dir / tosa_file).open("wb") as fb_out:
            fb_out.write(self.ser.serialize())

        # Get JSON descriptor from serializer
        desc = json.loads(self.ser.writeJson(tosa_file))
        if metaData:
            # Add extra meta data to desc.json
            desc["meta"] = metaData

        # Validate desc.json before we output it
        self.descSchemaValidator.validate_config(desc)

        def write_cpp(cpp_path, banner, declaration, payload):
            # Emit the payload as JSON wrapped in a C++ raw string constant
            with cpp_path.open("w") as cpp_out:
                cpp_out.write(TOSA_AUTOGENERATED_HEADER)
                cpp_out.write(banner)
                cpp_out.write(declaration)
                json.dump(payload, cpp_out)
                cpp_out.write(')";\n\n')

        if metaData:
            if "data_gen" in metaData and self.args.lazy_data_gen:
                # Output datagen meta data as CPP data
                write_cpp(
                    test_dir / f"{testName}_meta_data_gen.cpp",
                    "// Test meta data for data generation setup\n\n",
                    f'const char* json_tdg_config_{test_dir.stem} = R"(',
                    metaData["data_gen"],
                )
            if "compliance" in metaData:
                # Output compliance meta data as CPP data
                write_cpp(
                    test_dir / f"{testName}_meta_compliance.cpp",
                    "// Test meta data for compliance validation\n\n",
                    f'const char* json_tvf_config_{test_dir.stem} = R"(',
                    metaData["compliance"],
                )

        # Write desc.json
        with (test_dir / "desc.json").open("w") as desc_out:
            json.dump(desc, desc_out, indent=1)

    def resetRNG(self, seed=None):
        if seed is None:
            seed = self.random_seed + 1
        self.rng = np.random.default_rng(seed)

    def getDTypeRange(self, dtype, high_inclusive=False):
        """Return the (low, high) value range boundaries for dtype.

        For floating point types the pre-computed random float range is
        returned as is; high_inclusive has no effect on it. For all other
        types the high boundary is excluded from the range unless
        high_inclusive is True, in which case it is reduced by one.

        Raises Exception for a data type that is not known.
        """
        if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
            return self.random_float_range[dtype]

        # Fixed ranges, stored with an exclusive high boundary
        fixed_bounds = (
            (DType.BOOL, (0, 2)),
            (DType.UINT8, (0, 256)),
            (DType.UINT16, (0, 65536)),
            # TOSA specific INT4 weight range from -7 to 7
            (DType.INT4, (-7, 8)),
            (DType.INT8, (-128, 128)),
            (DType.INT16, (-32768, 32768)),
            (DType.INT32, (-(1 << 31), (1 << 31))),
            (DType.INT48, (-(1 << 47), (1 << 47))),
        )

        if dtype == DType.SHAPE:
            # Shape values follow the tensor shape range program argument
            bounds = tuple(self.args.tensor_shape_range[0:2])
        else:
            bounds = next(
                (limits for known, limits in fixed_bounds if dtype == known), None
            )
            if bounds is None:
                raise Exception("Unknown dtype: {}".format(dtype))

        if high_inclusive:
            # Inclusive range: low <= range <= high
            return (bounds[0], bounds[1] - 1)
        # Exclusive high: low <= range < high
        return bounds

    def getRandTensor(self, shape, dtype, data_range=None):
        """Return a numpy array of random values of the given shape.

        Values are drawn from data_range (low, high) when supplied,
        otherwise from getDTypeRange(dtype). BOOL values ignore the range.
        FP16 data is returned as float16; the other floating point types
        are returned as float32 after conversion with the matching gtu
        helper (FP32 is returned unconverted).
        """
        low, high = self.getDTypeRange(dtype) if data_range is None else data_range

        if dtype == DType.BOOL:
            return np.bool_(self.rng.choice(a=[False, True], size=shape))

        float_dtypes = (
            DType.FP16,
            DType.BF16,
            DType.FP32,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
        if dtype in float_dtypes:
            values = self.rng.uniform(low=low, high=high, size=shape)
            if dtype == DType.FP16:
                return np.float16(values)

            values_f32 = np.float32(values)
            if dtype == DType.BF16:
                # Floor the last 16 bits of each f32 value
                return np.float32(gtu.vect_f32_to_bf16(values_f32))
            if dtype == DType.FP8E4M3:
                return np.float32(gtu.vect_f32_to_fp8e4m3(values_f32))
            if dtype == DType.FP8E5M2:
                return np.float32(gtu.vect_f32_to_fp8e5m2(values_f32))
            return values_f32

        # Integer storage type for each data type; anything not listed
        # is stored as int32
        storage_types = (
            ((DType.INT4, DType.INT8), np.int8),
            ((DType.UINT8,), np.uint8),
            ((DType.INT16,), np.int16),
            ((DType.UINT16,), np.uint16),
            ((DType.INT48, DType.SHAPE), np.int64),
        )
        to_storage = next(
            (cast for members, cast in storage_types if dtype in members), np.int32
        )
        return to_storage(self.rng.integers(low=low, high=high, size=shape))

    def buildPlaceholderTensors(self, shape_list, dtype_list):
        placeholders = []

        assert len(shape_list) == len(dtype_list)

        arr = None
        for idx, shape in enumerate(shape_list):
            if not self.args.lazy_data_gen:
                arr = self.getRandTensor(shape, dtype_list[idx])
            placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))

        return placeholders

    def buildConstTensors(self, shape_list, dtype_list):
        consts = []

        assert len(shape_list) == len(dtype_list)

        arr = None
        for idx, shape in enumerate(shape_list):
            if not self.args.lazy_data_gen:
                arr = self.getRandTensor(shape, dtype_list[idx])
            consts.append(self.ser.addConst(shape, dtype_list[idx], arr))

        return consts

    def makeShape(self, rank):
        if self.targetted_shape:
            return np.int32(self.targetted_shape)
        return np.int32(
            self.rng.integers(
                low=self.args.tensor_shape_range[0],
                high=self.args.tensor_shape_range[1],
                size=rank,
            )
        )

    def setTargetShape(self, shape):
        self.targetted_shape = shape

    def randInt(self, low=0, high=256):
        return np.int32(self.rng.integers(low=low, high=high, size=1))[0]

    def getRandNumberDType(self, dtype):
        """Return a single random value within the range of dtype.

        The range comes from getDTypeRange(dtype). BOOL picks between
        False and True; INT48 and SHAPE values are int64, all other
        integer values are int32.
        """
        low, high = self.getDTypeRange(dtype)

        if dtype == DType.BOOL:
            return self.rng.choice([False, True])

        if dtype in (
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ):
            sample = self.rng.uniform(low=low, high=high)
            if dtype == DType.FP16:
                return np.float16(sample)

            sample_f32 = np.float32(sample)
            if dtype == DType.BF16:
                return gtu.vect_f32_to_bf16(sample_f32)
            if dtype == DType.FP8E4M3:
                return gtu.vect_f32_to_fp8e4m3(sample_f32)
            if dtype == DType.FP8E5M2:
                return gtu.vect_f32_to_fp8e5m2(sample_f32)
            return sample_f32

        # Special size for INT48 and SHAPE
        to_storage = np.int64 if dtype in (DType.INT48, DType.SHAPE) else np.int32
        return to_storage(self.rng.integers(low, high, size=1))[0]

    def shapeStr(self, shape):
        """Return the dimensions of shape as strings joined by "x"."""
        return "x".join(str(dim) for dim in shape)

    def typeStr(self, dtype):
        """Return the string name of a data type, or of a list/tuple of them.

        A list or tuple must have at least two entries; the names of the
        first two are joined by "x". Raises Exception for a data type
        that has no entry in gtu.DTYPE_ATTRIBUTES.
        """
        if isinstance(dtype, (list, tuple)):
            assert len(dtype) >= 2
            names = [self.typeStr(entry) for entry in dtype]
            # Limit types to the first 2 as the 3rd is the accumulator
            return "x".join(names[:2])

        if dtype not in gtu.DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]

    def constrictBatchSize(self, shape):
        # Limit the batch size unless an explicit target shape set
        if self.args.max_batch_size and not self.args.target_shapes:
            shape[0] = min(shape[0], self.args.max_batch_size)
        return shape

    def makeDimension(self):
        return self.randInt(
            low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
        )

    def tensorComplianceMetaData(
        self, op, inputType, argsDict, outputTensor, errorName
    ):
        """Create the compliance meta data dictionary for one output tensor.

        Returns None for error tests and for data types that compliance
        does not support. Otherwise returns a dictionary holding the
        compliance mode name, the output data type and any extra
        information that the selected mode needs.
        """
        # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
        UNSUPPORTED_NON_FP32_INPUT_OPS = (
            Op.MATMUL,
            Op.CONV2D,
            Op.FULLY_CONNECTED,
            Op.DEPTHWISE_CONV2D,
            Op.TRANSPOSE_CONV2D,
            Op.CONV3D,
        )

        # No compliance for error tests or unsupported types currently
        if errorName:
            return None
        if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
            return None
        if (
            not gtu.dtypeIsSupportedByCompliance(inputType)
            and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
        ):
            return None

        def op_declares(setting):
            # True when the operator definition has this compliance setting
            return "compliance" in op and setting in op["compliance"]

        # Create compliance meta data for expected output tensor
        compliance_tens = {
            "mode": None,
            # Data type is needed for all FP runs, as refmodel precise mode produces FP64
            "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
        }

        dg_type = argsDict["dg_type"]
        if dg_type == gtu.DataGenType.DOT_PRODUCT:
            mode = gtu.ComplianceMode.DOT_PRODUCT
            compliance_tens["dot_product_info"] = {
                "s": argsDict["s"],
                "ks": int(argsDict["ksb" if "ksb" in argsDict else "ks"]),
            }
        elif dg_type == gtu.DataGenType.SPECIAL:
            mode = gtu.ComplianceMode.FP_SPECIAL
        elif op_declares("ulp"):
            mode = gtu.ComplianceMode.ULP
            compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
        elif op_declares("relative"):
            mode = gtu.ComplianceMode.RELATIVE
            compliance_tens["relative_info"] = {
                "max": argsDict["max_abs_value"],
                "scale": op["compliance"]["relative"],
            }
        elif op["op"] == Op.REDUCE_PRODUCT:
            mode = gtu.ComplianceMode.REDUCE_PRODUCT
            compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
        elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
            mode = gtu.ComplianceMode.ABS_ERROR
            if op_declares("abs_error_lower_bound"):
                compliance_tens["abs_error_info"] = {
                    "lower_bound": op["compliance"]["abs_error_lower_bound"]
                }
        elif op["op"] in (Op.SIN, Op.COS):
            mode = gtu.ComplianceMode.ABS_ERROR
            if op_declares("abs_error_normal_divisor"):
                compliance_tens["abs_error_info"] = {
                    "normal_divisor": op["compliance"]["abs_error_normal_divisor"]
                }
        else:
            mode = gtu.ComplianceMode.EXACT

        compliance_tens["mode"] = gtu.ComplianceMode(mode).name
        return compliance_tens

    # Build Op functions
    # Create the output tensor (calling OutputShaper as needed)
    # Do final tweaks to attributes (if necessary for errorIf)
    # Add Op into graph
    # Return resulting tensor information or BuildInfo

    class BuildInfo:
        """Result tensor(s) of a build function and their compliance dicts.

        Single values are wrapped so that resultTensorList is always a
        list; complianceDictList is either None or a list.
        """

        def __init__(self, resultTensor, complianceDict):
            if not isinstance(resultTensor, list):
                # Single result - wrap the tensor and its compliance info
                self.resultTensorList = [resultTensor]
                self.complianceDictList = (
                    None if complianceDict is None else [complianceDict]
                )
                return

            assert complianceDict is None or isinstance(complianceDict, list)
            self.resultTensorList = resultTensor
            self.complianceDictList = complianceDict

        def getComplianceInfo(self):
            """Return the compliance meta data, or None when there is none.

            Tensors whose compliance entry is None are left out.
            """
            if self.complianceDictList is None:
                return None

            per_tensor = {
                tens.name: comp
                for tens, comp in zip(self.resultTensorList, self.complianceDictList)
                if comp is not None
            }
            if not per_tensor:
                return None

            # Have some compliance data, so return the info
            return {
                "version": "0.1",
                "tensors": per_tensor,
            }

    def build_unary(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a single-input operator to the graph.

        The output tensor comes from OutputShaper.unaryOp. For NEGATE the
        zero points in qinfo are serialized as the operator attribute; when
        no qinfo was supplied they default to zero instead of failing on
        the missing list.

        Returns a BuildInfo holding the result tensor and its compliance
        meta data, or None (without adding the operator) when
        TosaErrorValidator.evValidateErrorIfs returns a false value.
        """
        assert len(inputs) == 1
        a = inputs[0]
        result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        assert not isinstance(op, int)

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongOutputType:
            if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
                qinfo = [
                    TosaQuantGen.getZeroPoint(self, a.dtype),
                    TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
                ]

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            qinfo=qinfo,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = None
        if op["op"] == Op.NEGATE:
            # qinfo is optional in the signature, so use zero for both zero
            # points when the caller did not supply any (the validator above
            # still sees the original value)
            if qinfo is None:
                qinfo = [0, 0]
            attr = ts.TosaSerializerAttribute()
            attr.NegateAttribute(qinfo[0], qinfo[1])

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )
        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_binary_broadcast(
        self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
    ):
        assert len(inputs) == 2
        a, b = inputs
        result_tensor = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, a, b, error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=a,
            input2=b,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(op["op"], input_list, output_list)

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
        """Add a two-input non-broadcasting operator to the graph.

        validator_fcns and error_name are accepted but not used. Returns
        the result tensor itself rather than a BuildInfo.
        """
        out_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
        self.ser.addOperator(
            op["op"],
            [a.name, b.name],
            [out_tens.name],
        )
        return out_tens

    def build_arithmetic_right_shift(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add an arithmetic right shift operator to the graph.

        The rounding setting is read from args_dict["round"] and serialized
        as the operator attribute. qinfo is accepted but not used.

        Returns a BuildInfo holding the result tensor and its compliance
        meta data, or None (without adding the operator) when
        TosaErrorValidator.evValidateErrorIfs returns a false value.
        """
        assert len(inputs) == 2
        lhs, rhs = inputs
        rounding = args_dict["round"]
        out_tens = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, lhs, rhs, error_name
        )

        # Invalidate Input/Output list for error if checks.
        placeholder_count, const_count = op["operands"]
        in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, [lhs.name, rhs.name], [out_tens.name]
        )

        checks = dict(
            op=op,
            input1=lhs,
            input2=rhs,
            input_dtype=lhs.dtype,
            output_dtype=out_tens.dtype,
            result_tensors=[out_tens],
            input_list=in_names,
            output_list=out_names,
            num_operands=placeholder_count + const_count,
        )
        valid = TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **checks
        )
        if not valid:
            return None

        shift_attr = ts.TosaSerializerAttribute()
        shift_attr.ArithmeticRightShiftAttribute(rounding)
        self.ser.addOperator(op["op"], in_names, out_names, shift_attr)

        compliance = self.tensorComplianceMetaData(
            op, lhs.dtype, args_dict, out_tens, error_name
        )
        return TosaTestGen.BuildInfo(out_tens, compliance)

    def build_mul(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a multiply operator to the graph.

        Takes three inputs: the two operands and the shift value tensor.
        The result type is forced to INT32 unless the first operand is
        FP16, BF16 or FP32. qinfo is accepted but not used.

        Returns a BuildInfo holding the result tensor and its compliance
        meta data, or None (without adding the operator) when
        TosaErrorValidator.evValidateErrorIfs returns a false value.
        """
        # Note that mul is binary operator but it has a shift value tensor
        assert len(inputs) == 3
        lhs, rhs, shift = inputs

        out_tens = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, lhs, rhs, error_name
        )

        # Special for multiply: Force the result to INT32 for INT types
        if lhs.dtype not in (DType.FP16, DType.BF16, DType.FP32):
            out_tens.setDtype(DType.INT32)

        if error_name == ErrorIf.WrongOutputType:
            # Pick one of these output types at random for the error test
            out_tens.setDtype(
                self.rng.choice([DType.INT8, DType.INT16, DType.INT48])
            )

        # Invalidate Input/Output list for error if checks.
        placeholder_count, const_count = op["operands"]
        in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, [lhs.name, rhs.name, shift.name], [out_tens.name]
        )

        checks = dict(
            op=op,
            input1=lhs,
            input2=rhs,
            input_dtype=lhs.dtype,
            output_dtype=out_tens.dtype,
            result_tensors=[out_tens],
            input_list=in_names,
            output_list=out_names,
            num_operands=placeholder_count + const_count,
        )
        valid = TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **checks
        )
        if not valid:
            return None

        self.ser.addOperator(op["op"], in_names, out_names)

        compliance = self.tensorComplianceMetaData(
            op, lhs.dtype, args_dict, out_tens, error_name
        )
        return TosaTestGen.BuildInfo(out_tens, compliance)

    def build_table(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a table operator to the graph.

        The table is read from args_dict["table"] and serialized as the
        operator attribute. qinfo is accepted but not used.

        Returns a BuildInfo holding the result tensor and its compliance
        meta data, or None (without adding the operator) when
        TosaErrorValidator.evValidateErrorIfs returns a false value.
        """
        assert len(inputs) == 1
        in_tens = inputs[0]
        table = args_dict["table"]
        out_tens = OutputShaper.tableOp(self.ser, self.rng, in_tens, error_name)

        table_attr = ts.TosaSerializerAttribute()
        table_attr.TableAttribute(table)

        # Invalidate Input/Output list for error if checks.
        placeholder_count, const_count = op["operands"]
        in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, [in_tens.name], [out_tens.name]
        )

        checks = dict(
            op=op,
            input_shape=in_tens.shape,
            input_dtype=in_tens.dtype,
            output_dtype=out_tens.dtype,
            result_tensors=[out_tens],
            input_list=in_names,
            output_list=out_names,
            num_operands=placeholder_count + const_count,
        )
        valid = TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **checks
        )
        if not valid:
            return None

        self.ser.addOperator(op["op"], in_names, out_names, table_attr)

        compliance = self.tensorComplianceMetaData(
            op, in_tens.dtype, args_dict, out_tens, error_name
        )
        return TosaTestGen.BuildInfo(out_tens, compliance)

    def build_select(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a select operator to the graph.

        Takes three inputs: the condition tensor followed by the two
        value tensors. qinfo is accepted but not used.

        Returns a BuildInfo holding the result tensor and its compliance
        meta data, or None (without adding the operator) when
        TosaErrorValidator.evValidateErrorIfs returns a false value.
        """
        assert len(inputs) == 3
        cond, on_true, on_false = inputs

        out_tens = OutputShaper.selectOp(
            self.ser, self.rng, cond, on_true, on_false, error_name
        )

        # Invalidate Input/Output list for error if checks.
        placeholder_count, const_count = op["operands"]
        in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self,
            error_name,
            [cond.name, on_true.name, on_false.name],
            [out_tens.name],
        )

        checks = dict(
            op=op,
            input1=cond,
            input2=on_true,
            input3=on_false,
            input_shape=on_true.shape,
            input_dtype=on_true.dtype,
            output_dtype=out_tens.dtype,
            result_tensors=[out_tens],
            input_list=in_names,
            output_list=out_names,
            num_operands=placeholder_count + const_count,
        )
        valid = TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **checks
        )
        if not valid:
            return None

        self.ser.addOperator(op["op"], in_names, out_names)

        compliance = self.tensorComplianceMetaData(
            op, on_true.dtype, args_dict, out_tens, error_name
        )
        return TosaTestGen.BuildInfo(out_tens, compliance)

    def build_comparison(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a two-input comparison operator to the graph.

        The output tensor comes from OutputShaper.binaryComparisonOp.
        qinfo is accepted but not used.

        Returns a BuildInfo holding the result tensor and its compliance
        meta data, or None (without adding the operator) when
        TosaErrorValidator.evValidateErrorIfs returns a false value.
        """
        assert len(inputs) == 2
        lhs, rhs = inputs

        out_tens = OutputShaper.binaryComparisonOp(
            self.ser, self.rng, lhs, rhs, error_name
        )

        # Invalidate Input/Output list for error if checks.
        placeholder_count, const_count = op["operands"]
        in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, [lhs.name, rhs.name], [out_tens.name]
        )

        checks = dict(
            op=op,
            input1=lhs,
            input2=rhs,
            input_shape=lhs.shape,
            input_dtype=lhs.dtype,
            output_shape=out_tens.shape,
            output_dtype=out_tens.dtype,
            result_tensors=[out_tens],
            input_list=in_names,
            output_list=out_names,
            num_operands=placeholder_count + const_count,
        )
        valid = TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **checks
        )
        if not valid:
            return None

        self.ser.addOperator(op["op"], in_names, out_names)

        compliance = self.tensorComplianceMetaData(
            op, lhs.dtype, args_dict, out_tens, error_name
        )
        return TosaTestGen.BuildInfo(out_tens, compliance)

    def build_argmax(
        self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
    ):
        assert len(inputs) == 1
        a = inputs[0]
        axis = args_dict["axis"]
        result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            axis=axis,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_shape=result_tensor.shape,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, inputs[0].dtype, args_dict, result_tensor, error_name
        )
        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_pool2d(
        self,
        op,
        inputs,
        args_dict,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Add a 2D pooling operator to the graph.

        Stride, pad and kernel are read from args_dict, together with the
        optional accumulator type. The zero points in qinfo are serialized
        as part of the operator attribute and default to zero when no
        qinfo is available.

        Returns a BuildInfo holding the result tensor and its compliance
        meta data, or None (without adding the operator) when
        TosaErrorValidator.evValidateErrorIfs returns a false value.
        """
        assert len(inputs) == 1
        in_tens = inputs[0]
        # max_pool has no accum_dtype
        accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
        stride = args_dict["stride"]
        pad = args_dict["pad"]
        kernel = args_dict["kernel"]

        out_tens = OutputShaper.pool2dOp(
            self.ser, self.rng, in_tens, kernel, stride, pad, error_name
        )

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongInputType and in_tens.dtype not in [
            DType.INT8,
            DType.UINT8,
        ]:
            qinfo = [
                TosaQuantGen.getZeroPoint(self, in_tens.dtype),
                TosaQuantGen.getZeroPoint(self, out_tens.dtype),
            ]

        # Invalidate Input/Output list for error if checks.
        placeholder_count, const_count = op["operands"]
        in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, [in_tens.name], [out_tens.name]
        )

        checks = dict(
            op=op,
            input_shape=in_tens.shape,
            input_dtype=in_tens.dtype,
            output_shape=out_tens.shape,
            output_dtype=out_tens.dtype,
            accum_dtype=accum_dtype,
            kernel=kernel,
            stride=stride,
            pad=pad,
            qinfo=qinfo,
            result_tensors=[out_tens],
            input_list=in_names,
            output_list=out_names,
            num_operands=placeholder_count + const_count,
        )
        valid = TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **checks
        )
        if not valid:
            return None

        # Zero points default to zero when there is no quantization info
        input_zp, output_zp = (0, 0) if qinfo is None else (qinfo[0], qinfo[1])

        pool_attr = ts.TosaSerializerAttribute()
        pool_attr.PoolAttribute(kernel, stride, pad, input_zp, output_zp, accum_dtype)

        self.ser.addOperator(op["op"], in_names, out_names, pool_attr)

        compliance = self.tensorComplianceMetaData(
            op, inputs[0].dtype, args_dict, out_tens, error_name
        )
        return TosaTestGen.BuildInfo(out_tens, compliance)

    def build_conv2d(
        self,
        op,
        inputs,
        args_dict,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Add a 2D convolution operator to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: [ifm, weight, bias] tensors.
            args_dict: must hold "acc_type", "stride", "pad" (4 values) and
                "dilation".
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: two zero-point values, or None to use [0, 0].

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        assert len(inputs) == 3
        # "weight" rather than "filter" to avoid shadowing the builtin.
        ifm, weight, bias = inputs
        accum_dtype = args_dict["acc_type"]
        strides = args_dict["stride"]
        padding = args_dict["pad"]
        dilations = args_dict["dilation"]

        assert len(padding) == 4
        result_tensor = OutputShaper.conv2dOp(
            self.ser,
            self.rng,
            ifm,
            weight,
            accum_dtype,
            strides,
            padding,
            dilations,
            error_name,
        )

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
            DType.INT8,
            DType.UINT8,
        ):
            qinfo = [
                TosaQuantGen.getZeroPoint(self, ifm.dtype),
                TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
            ]

        # Invalidate Input/Output list for error_if checks.
        input_list = [ifm.name, weight.name, bias.name]
        output_list = [result_tensor.name]
        num_operands = sum(op["operands"])
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=ifm.dtype,
            weight_dtype=weight.dtype,
            output_dtype=result_tensor.dtype,
            qinfo=qinfo,
            input_list=input_list,
            num_operands=num_operands,
            output_list=output_list,
            pad=padding,
            stride=strides,
            dilation=dilations,
            input_shape=ifm.shape,
            weight_shape=weight.shape,
            output_shape=result_tensor.shape,
        ):
            return None

        # qinfo defaults to None in the signature; fall back to zero
        # zero-points (the same fallback build_pool2d applies) instead of
        # raising TypeError on qinfo[0]. Done after validation so the
        # validators still see the caller's value.
        if qinfo is None:
            qinfo = [0, 0]

        # TODO - Test local_bound, for now set local bound attribute to False
        local_bound = False

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, ifm.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_conv3d(
        self,
        op,
        inputs,
        args_dict,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Add a 3D convolution operator to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: [ifm, weight, bias] tensors.
            args_dict: must hold "acc_type", "stride", "pad" (6 values) and
                "dilation".
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: two zero-point values, or None to use [0, 0].

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        assert len(inputs) == 3
        # "weight" rather than "filter" to avoid shadowing the builtin.
        ifm, weight, bias = inputs
        accum_dtype = args_dict["acc_type"]
        strides = args_dict["stride"]
        padding = args_dict["pad"]
        dilations = args_dict["dilation"]

        assert len(padding) == 6
        result_tensor = OutputShaper.conv3dOp(
            self.ser,
            self.rng,
            ifm,
            weight,
            accum_dtype,
            strides,
            padding,
            dilations,
            error_name,
        )

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
            DType.INT8,
            DType.UINT8,
        ):
            qinfo = [
                TosaQuantGen.getZeroPoint(self, ifm.dtype),
                TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
            ]

        # Invalidate Input/Output list for error_if checks.
        input_list = [ifm.name, weight.name, bias.name]
        output_list = [result_tensor.name]
        num_operands = sum(op["operands"])
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=ifm.dtype,
            weight_dtype=weight.dtype,
            output_dtype=result_tensor.dtype,
            qinfo=qinfo,
            input_list=input_list,
            num_operands=num_operands,
            output_list=output_list,
            pad=padding,
            stride=strides,
            dilation=dilations,
            input_shape=ifm.shape,
            weight_shape=weight.shape,
            output_shape=result_tensor.shape,
        ):
            return None

        # qinfo defaults to None in the signature; fall back to zero
        # zero-points (the same fallback build_pool2d applies) instead of
        # raising TypeError on qinfo[0]. Done after validation so the
        # validators still see the caller's value.
        if qinfo is None:
            qinfo = [0, 0]

        # TODO - Test local_bound, for now set local bound attribute to False
        local_bound = False

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, ifm.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_transpose_conv2d(
        self,
        op,
        inputs,
        args_dict,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Add a 2D transpose convolution operator to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: [ifm, weight, bias] tensors.
            args_dict: must hold "acc_type", "stride", "pad" (4 values) and
                "out_shape".
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: two zero-point values, or None to use [0, 0].

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        assert len(inputs) == 3
        # "weight" rather than "filter" to avoid shadowing the builtin.
        ifm, weight, bias = inputs
        accum_dtype = args_dict["acc_type"]
        strides = args_dict["stride"]
        out_pad = args_dict["pad"]
        output_shape = args_dict["out_shape"]

        assert len(out_pad) == 4
        result_tensor = OutputShaper.transposeConv2DOp(
            self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
        )

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
            DType.INT8,
            DType.UINT8,
        ):
            qinfo = [
                TosaQuantGen.getZeroPoint(self, ifm.dtype),
                TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
            ]

        # Invalidate Input/Output list for error_if checks.
        input_list = [ifm.name, weight.name, bias.name]
        output_list = [result_tensor.name]
        num_operands = sum(op["operands"])
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=ifm.dtype,
            weight_dtype=weight.dtype,
            output_dtype=result_tensor.dtype,
            qinfo=qinfo,
            input_list=input_list,
            num_operands=num_operands,
            output_list=output_list,
            pad=out_pad,
            stride=strides,
            input_shape=ifm.shape,
            weight_shape=weight.shape,
            output_shape=result_tensor.shape,
        ):
            return None

        # qinfo defaults to None in the signature; fall back to zero
        # zero-points (the same fallback build_pool2d applies) instead of
        # raising TypeError on qinfo[0]. Done after validation so the
        # validators still see the caller's value.
        if qinfo is None:
            qinfo = [0, 0]

        # TODO - Test local_bound, for now set local bound attribute to False
        local_bound = False

        attr = ts.TosaSerializerAttribute()
        attr.TransposeConvAttribute(
            out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
        )

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, ifm.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_depthwise_conv2d(
        self,
        op,
        inputs,
        args_dict,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Add a depthwise 2D convolution operator to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: [ifm, weight, bias] tensors.
            args_dict: must hold "acc_type", "stride", "pad" and "dilation".
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: two zero-point values, or None to use [0, 0].

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        assert len(inputs) == 3
        # "weight" rather than "filter" to avoid shadowing the builtin.
        ifm, weight, bias = inputs
        accum_dtype = args_dict["acc_type"]
        strides = args_dict["stride"]
        padding = args_dict["pad"]
        dilations = args_dict["dilation"]

        result_tensor = OutputShaper.depthwiseConv2dOp(
            self.ser,
            self.rng,
            ifm,
            weight,
            accum_dtype,
            strides,
            padding,
            dilations,
            error_name,
        )

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
            DType.INT8,
            DType.UINT8,
        ):
            qinfo = [
                TosaQuantGen.getZeroPoint(self, ifm.dtype),
                TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
            ]

        # Invalidate Input/Output list for error_if checks.
        input_list = [ifm.name, weight.name, bias.name]
        output_list = [result_tensor.name]
        num_operands = sum(op["operands"])
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=ifm.dtype,
            weight_dtype=weight.dtype,
            output_dtype=result_tensor.dtype,
            qinfo=qinfo,
            input_list=input_list,
            num_operands=num_operands,
            output_list=output_list,
            pad=padding,
            stride=strides,
            dilation=dilations,
            input_shape=ifm.shape,
            weight_shape=weight.shape,
            output_shape=result_tensor.shape,
        ):
            return None

        # qinfo defaults to None in the signature; fall back to zero
        # zero-points (the same fallback build_pool2d applies) instead of
        # raising TypeError on qinfo[0]. Done after validation so the
        # validators still see the caller's value.
        if qinfo is None:
            qinfo = [0, 0]

        # TODO - Test local_bound, for now set local bound attribute to False
        local_bound = False

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, ifm.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_fully_connected(
        self,
        op,
        inputs,
        args_dict,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Add a fully connected operator to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: [ifm, weight, bias] tensors.
            args_dict: must hold "acc_type".
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: two zero-point values, or None to use [0, 0].

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        assert len(inputs) == 3
        # "weight" rather than "filter" to avoid shadowing the builtin.
        ifm, weight, bias = inputs
        accum_dtype = args_dict["acc_type"]

        result_tensor = OutputShaper.fullyConnectedOp(
            self.ser, self.rng, ifm, weight, accum_dtype, error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [ifm.name, weight.name, bias.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=ifm.shape,
            input_dtype=ifm.dtype,
            weight_dtype=weight.dtype,
            output_shape=result_tensor.shape,
            output_dtype=result_tensor.dtype,
            qinfo=qinfo,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
            accum_dtype=accum_dtype,
        ):
            return None

        # qinfo defaults to None in the signature; fall back to zero
        # zero-points (the same fallback build_pool2d applies) instead of
        # raising TypeError on qinfo[0]. Done after validation so the
        # validators still see the caller's value.
        if qinfo is None:
            qinfo = [0, 0]

        attr = ts.TosaSerializerAttribute()
        attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, ifm.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_matmul(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a matrix multiply operator to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: [a, b] tensors.
            args_dict: must hold "acc_type".
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: two zero-point values, or None to use [0, 0].

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        assert len(inputs) == 2
        a, b = inputs
        accum_dtype = args_dict["acc_type"]
        result_tensor = OutputShaper.matmulOp(
            self.ser, self.rng, a, b, accum_dtype, error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=a.shape,
            input_dtype=a.dtype,
            input2_shape=b.shape,
            input2_dtype=b.dtype,
            output_shape=result_tensor.shape,
            output_dtype=result_tensor.dtype,
            qinfo=qinfo,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
            accum_dtype=accum_dtype,
        ):
            return None

        # qinfo defaults to None in the signature; fall back to zero
        # zero-points (the same fallback build_pool2d applies) instead of
        # raising TypeError on qinfo[0]. Done after validation so the
        # validators still see the caller's value.
        if qinfo is None:
            qinfo = [0, 0]

        attr = ts.TosaSerializerAttribute()
        attr.MatMulAttribute(qinfo[0], qinfo[1])

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_reduce(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a reduction operator to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: list containing exactly one input tensor.
            args_dict: must hold "axis". For Op.REDUCE_PRODUCT with no
                error_name, this dict is updated in place with "n" set to
                the input's size along the reduced axis.
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs. Defaults to None,
                matching the other build_* methods.
            error_name: the ErrorIf case under test, or None.
            qinfo: accepted for signature parity; not used here.

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        assert len(inputs) == 1
        a = inputs[0]
        axis = args_dict["axis"]
        result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            axis=axis,
            input_shape=a.shape,
            output_shape=result_tensor.shape,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
            # Number of products - needed for compliance
            args_dict["n"] = a.shape[axis]

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_clamp(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a clamp operator with randomly drawn min/max limits.

        Two values are drawn with self.getRandNumberDType for the input
        dtype. For ErrorIf.MaxSmallerMin they are redrawn until distinct and
        then deliberately swapped so that max < min.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: list containing exactly one input tensor.
            args_dict: forwarded to tensorComplianceMetaData.
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: accepted for signature parity; not used here.

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        assert len(inputs) == 1
        a = inputs[0]

        result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        def draw_pair():
            # Two independent random values of the input's dtype.
            return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

        limits = draw_pair()

        if error_name == ErrorIf.MaxSmallerMin:
            # Make sure the numbers are different to invoke this error
            while limits[0] == limits[1]:
                limits = draw_pair()
            max_val, min_val = min(limits), max(limits)
        else:
            max_val, min_val = max(limits), min(limits)

        # Invalidate Input/Output list for error if checks.
        in_names = [a.name]
        out_names = [result_tensor.name]
        p_count, c_count = op["operands"]
        num_operands = p_count + c_count
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, in_names, out_names
        )

        checks_passed = TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            max_val=max_val,
            min_val=min_val,
            input_shape=a.shape,
            output_shape=result_tensor.shape,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )
        if not checks_passed:
            return None

        attr = ts.TosaSerializerAttribute()

        # Choose the struct format and the values to encode per dtype.
        if a.dtype in (DType.BF16, DType.FP16, DType.FP32):
            if a.dtype == DType.FP16:
                # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
                min_val = min_val.astype(np.float32)
                max_val = max_val.astype(np.float32)
            pack_fmt, low, high = "<f", min_val, max_val
        elif a.dtype in (DType.INT8, DType.INT16):
            pack_fmt, low, high = "<i", min_val, max_val
        else:
            # to avoid internal error for incorrect input types
            pack_fmt, low, high = "<i", 0, 0

        min_val_as_bytes = struct.pack(pack_fmt, low)
        max_val_as_bytes = struct.pack(pack_fmt, high)

        attr.ClampAttribute(self.ser.builder, min_val_as_bytes, max_val_as_bytes)

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )
        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
        """Add a leaky ReLU operator with a random FP32 attribute value.

        Unlike most build_* methods this takes the tensor directly and
        returns the result tensor itself; validator_fcns is not used.
        """
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        attr = ts.TosaSerializerAttribute()
        rand_fp32 = self.getRandNumberDType(DType.FP32)
        attr.LeakyReluAttribute(rand_fp32)

        in_names = [a.name]
        out_names = [result_tens.name]
        self.ser.addOperator(op["op"], in_names, out_names, attr)
        return result_tens

    # Needs an additional type/input
    def build_prelu(self, op, a, validator_fcns=None, error_name=None):
        """Add a single-input operator with no attribute and return its output.

        validator_fcns is accepted but not used; the result tensor itself is
        returned.
        """
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        in_names = [a.name]
        out_names = [result_tens.name]
        self.ser.addOperator(op["op"], in_names, out_names)
        return result_tens

    def build_activation(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add an attribute-less unary activation operator to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: list containing exactly one input tensor.
            args_dict: forwarded to tensorComplianceMetaData.
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: accepted for signature parity; not used here.

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        assert len(inputs) == 1
        operand = inputs[0]

        result_tensor = OutputShaper.unaryOp(self.ser, self.rng, operand, error_name)

        # Invalidate Input/Output list for error if checks.
        in_names = [operand.name]
        out_names = [result_tensor.name]
        p_count, c_count = op["operands"]
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, in_names, out_names
        )

        validation_kwargs = {
            "op": op,
            "input_shape": operand.shape,
            "output_shape": result_tensor.shape,
            "input_dtype": operand.dtype,
            "output_dtype": result_tensor.dtype,
            "result_tensors": [result_tensor],
            "input_list": input_list,
            "output_list": output_list,
            "num_operands": p_count + c_count,
        }
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **validation_kwargs
        ):
            return None

        self.ser.addOperator(op["op"], input_list, output_list)

        compliance = self.tensorComplianceMetaData(
            op, operand.dtype, args_dict, result_tensor, error_name
        )
        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_concat(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a CONCAT or CONCAT_SHAPE operator to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
                op["op"] must be Op.CONCAT or Op.CONCAT_SHAPE.
            inputs: the tensors to concatenate.
            args_dict: must hold "axis" unless the op is Op.CONCAT_SHAPE,
                for which axis 0 is used and no attribute is emitted.
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: accepted for signature parity; not used here.

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        axis = 0 if op["op"] == Op.CONCAT_SHAPE else args_dict["axis"]
        if error_name != ErrorIf.WrongInputType:
            # isinstance is the idiomatic type check (was: type(axis) == int)
            assert isinstance(axis, int)

        result_tensor = OutputShaper.concatOp(
            self.ser, self.rng, axis, inputs, error_name=error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [tensor.name for tensor in inputs]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            axis=axis,
            input_shape=inputs[0].shape,
            output_shape=result_tensor.shape,
            input_dtype=inputs[0].dtype,
            output_dtype=result_tensor.dtype,
            inputs=inputs,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        if op["op"] == Op.CONCAT:
            attr = ts.TosaSerializerAttribute()
            attr.AxisAttribute(axis)
        else:
            assert op["op"] == Op.CONCAT_SHAPE
            attr = None
        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, inputs[0].dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_pad(
        self,
        op,
        inputs,
        args_dict,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Add a pad operator to the serializer.

        The pad constant is packed as a little-endian float when the input
        dtype is floating point (per gtu.dtypeIsFloat), otherwise as a
        little-endian int.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: [tensor to pad, padding tensor].
            args_dict: must hold "pad", "pad_const_int" and "pad_const_fp".
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: forwarded to the validator only.

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        assert len(inputs) == 2
        a, pad_input = inputs[0], inputs[1]
        padding = args_dict["pad"]
        int_const = args_dict["pad_const_int"]
        fp_const = args_dict["pad_const_fp"]

        result_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

        # get pad_const_val_as_bytes from either the float or the int constant
        if gtu.dtypeIsFloat(a.dtype):
            pack_fmt, pad_const = "<f", fp_const
        else:
            pack_fmt, pad_const = "<i", int_const
        pad_const_val_as_bytes = struct.pack(pack_fmt, pad_const)

        attr = ts.TosaSerializerAttribute()
        attr.PadAttribute(self.ser.builder, pad_const_val_as_bytes)

        # Invalidate Input/Output list for error if checks.
        in_names = [a.name, pad_input.name]
        out_names = [result_tensor.name]
        p_count, c_count = op["operands"]
        num_operands = p_count + c_count
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, in_names, out_names
        )

        checks_passed = TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=a.shape,
            output_shape=result_tensor.shape,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            pad=padding,
            qinfo=qinfo,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
            input1=a,
        )
        if not checks_passed:
            return None

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )
        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_dim(
        self,
        op,
        inputs,
        args_dict,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Add a dim operator (with an axis attribute) to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: list containing exactly one input tensor.
            args_dict: must hold "axis".
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: accepted for signature parity; not used here.

        Returns:
            TosaTestGen.BuildInfo(result_tensor, None) - no compliance
            meta-data is produced - or None when the error-if validation
            does not pass.
        """
        assert len(inputs) == 1
        tensor = inputs[0]
        axis = args_dict["axis"]
        result_tensor = OutputShaper.dimOp(
            self.ser, self.rng, tensor, axis, error_name
        )

        # Invalidate Input/Output list for error if checks.
        in_names = [tensor.name]
        out_names = [result_tensor.name]
        p_count, c_count = op["operands"]
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, in_names, out_names
        )

        validation_kwargs = {
            "op": op,
            "axis": axis,
            "input_shape": tensor.shape,
            "input_dtype": tensor.dtype,
            "output_shape": result_tensor.shape,
            "output_dtype": result_tensor.dtype,
            "result_tensors": [result_tensor],
            "input_list": input_list,
            "output_list": output_list,
            "num_operands": p_count + c_count,
        }
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **validation_kwargs
        ):
            return None

        axis_attr = ts.TosaSerializerAttribute()
        axis_attr.AxisAttribute(axis)

        self.ser.addOperator(op["op"], input_list, output_list, axis_attr)
        return TosaTestGen.BuildInfo(result_tensor, None)

    def build_reshape(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a reshape operator to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: [tensor to reshape, shape tensor].
            args_dict: must hold "new_shape", used to compute the output.
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: accepted for signature parity; not used here.

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        assert len(inputs) == 2
        a, shape = inputs[0], inputs[1]
        new_shape = args_dict["new_shape"]
        result_tensor = OutputShaper.reshapeOp(
            self.ser, self.rng, a, new_shape, error_name
        )

        # Invalidate Input/Output list for error if checks.
        in_names = [a.name, shape.name]
        out_names = [result_tensor.name]
        p_count, c_count = op["operands"]
        num_operands = p_count + c_count
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, in_names, out_names
        )

        checks_passed = TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=a.shape,
            output_shape=result_tensor.shape,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        )
        if not checks_passed:
            return None

        self.ser.addOperator(op["op"], input_list, output_list)

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )
        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_reverse(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a reverse operator (with an axis attribute) to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: list containing exactly one input tensor.
            args_dict: must hold "axis".
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: accepted for signature parity; not used here.

        Returns:
            TosaTestGen.BuildInfo(result_tensor, None) - no compliance
            meta-data is produced - or None when the error-if validation
            does not pass.
        """
        assert len(inputs) == 1
        tensor = inputs[0]
        axis = args_dict["axis"]
        result_tensor = OutputShaper.unaryOp(self.ser, self.rng, tensor, error_name)

        # Invalidate Input/Output list for error if checks.
        in_names = [tensor.name]
        out_names = [result_tensor.name]
        p_count, c_count = op["operands"]
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, in_names, out_names
        )

        validation_kwargs = {
            "op": op,
            "axis": axis,
            "input_shape": tensor.shape,
            "output_shape": result_tensor.shape,
            "input_dtype": tensor.dtype,
            "output_dtype": result_tensor.dtype,
            "result_tensors": [result_tensor],
            "input_list": input_list,
            "output_list": output_list,
            "num_operands": p_count + c_count,
        }
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **validation_kwargs
        ):
            return None

        axis_attr = ts.TosaSerializerAttribute()
        axis_attr.AxisAttribute(axis)

        self.ser.addOperator(op["op"], input_list, output_list, axis_attr)
        return TosaTestGen.BuildInfo(result_tensor, None)

    def build_transpose(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a transpose operator to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: list containing exactly one input tensor.
            args_dict: must hold "perms", used for both the output tensor
                and the transpose attribute.
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: accepted for signature parity; not used here.

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        assert len(inputs) == 1
        a = inputs[0]
        perms = args_dict["perms"]

        result_tensor = OutputShaper.transposeOp(
            self.ser, self.rng, a, perms, error_name
        )

        # The attribute is built before validation, as in the original flow.
        perms_attr = ts.TosaSerializerAttribute()
        perms_attr.TransposeAttribute(perms)

        # Invalidate Input/Output list for error if checks.
        in_names = [a.name]
        out_names = [result_tensor.name]
        p_count, c_count = op["operands"]
        num_operands = p_count + c_count
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, in_names, out_names
        )

        checks_passed = TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=a.shape,
            output_shape=result_tensor.shape,
            perms=perms,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
            input1=a,
        )
        if not checks_passed:
            return None

        self.ser.addOperator(op["op"], input_list, output_list, perms_attr)

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )
        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_slice(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a slice operator to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: [tensor to slice, start tensor, size tensor].
            args_dict: must hold "start" and "size"; these values are used
                to compute the output tensor and for validation.
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: accepted for signature parity; not used here.

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        assert len(inputs) == 3
        a, start_var, size_var = inputs
        start = args_dict["start"]
        size = args_dict["size"]

        result_tensor = OutputShaper.sliceOp(
            self.ser, self.rng, a, start, size, error_name
        )

        # Invalidate Input/Output list for error if checks.
        in_names = [a.name, start_var.name, size_var.name]
        out_names = [result_tensor.name]
        p_count, c_count = op["operands"]
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, in_names, out_names
        )

        validation_kwargs = {
            "op": op,
            "input_shape": a.shape,
            "output_shape": result_tensor.shape,
            "input_dtype": a.dtype,
            "output_dtype": result_tensor.dtype,
            "start": start,
            "size": size,
            "result_tensors": [result_tensor],
            "input_list": input_list,
            "output_list": output_list,
            "num_operands": p_count + c_count,
            "input1": a,
        }
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **validation_kwargs
        ):
            return None

        self.ser.addOperator(op["op"], input_list, output_list)

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )
        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_tile(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a tile operator to the serializer.

        Args:
            op: operator definition dict; "op" and "operands" are read here.
            inputs: [tensor to tile, multiples tensor].
            args_dict: must hold "multiples", used to compute the output.
            validator_fcns: validator functions handed on to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: the ErrorIf case under test, or None.
            qinfo: accepted for signature parity; not used here.

        Returns:
            TosaTestGen.BuildInfo(result_tensor, compliance), or None when
            the error-if validation does not pass.
        """
        assert len(inputs) == 2
        a, multiples = inputs[0], inputs[1]
        multiples_values = args_dict["multiples"]
        result_tensor = OutputShaper.tileOp(
            self.ser, self.rng, a, multiples_values, error_name
        )

        # Invalidate Input/Output list for error if checks.
        in_names = [a.name, multiples.name]
        out_names = [result_tensor.name]
        p_count, c_count = op["operands"]
        num_operands = p_count + c_count
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, in_names, out_names
        )

        checks_passed = TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=a.shape,
            output_shape=result_tensor.shape,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
            input1=a,
        )
        if not checks_passed:
            return None

        self.ser.addOperator(op["op"], input_list, output_list)

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )
        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_gather(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a gather operator to the serializer.

        Args:
            op: Operator definition; its "op" and "operands" entries are read.
            inputs: Exactly two tensors - the values and the indices.
            args_dict: Passed through to the compliance metadata.
            validator_fcns: ERROR_IF validators handed to the validation step.
            error_name: ERROR_IF case being generated, or None.
            qinfo: Not used by this builder.

        Returns:
            TosaTestGen.BuildInfo for the result tensor, or None when the
            ERROR_IF validation does not pass.
        """
        assert len(inputs) == 2
        values, indices = inputs

        out_tens = OutputShaper.gatherOp(
            self.ser, self.rng, values, indices, error_name
        )

        p_count, c_count = op["operands"]
        operand_total = p_count + c_count

        # The ERROR_IF generator may invalidate the input/output name lists.
        in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, [values.name, indices.name], [out_tens.name]
        )

        checks = {
            "op": op,
            "input_shape": values.shape,
            "output_shape": out_tens.shape,
            "input_dtype": values.dtype,
            "output_dtype": out_tens.dtype,
            "result_tensors": [out_tens],
            "input_list": in_names,
            "output_list": out_names,
            "num_operands": operand_total,
        }
        passed = TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **checks
        )
        if not passed:
            return None

        self.ser.addOperator(op["op"], in_names, out_names)

        meta = self.tensorComplianceMetaData(
            op, values.dtype, args_dict, out_tens, error_name
        )
        return TosaTestGen.BuildInfo(out_tens, meta)

    def build_scatter(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a scatter operator to the serializer.

        Args:
            op: Operator definition; its "op" and "operands" entries are read.
            inputs: Exactly three tensors, in the order values_in, indices,
                input.
            args_dict: Passed through to the compliance metadata.
            validator_fcns: ERROR_IF validators handed to the validation step.
            error_name: ERROR_IF case being generated, or None.
            qinfo: Not used by this builder.

        Returns:
            TosaTestGen.BuildInfo for the result tensor, or None when the
            ERROR_IF validation does not pass.
        """
        assert len(inputs) == 3
        values_in, indices, input_tens = inputs
        out_tens = OutputShaper.scatterOp(
            self.ser, self.rng, values_in, indices, input_tens, error_name
        )

        p_count, c_count = op["operands"]
        operand_total = p_count + c_count

        # The ERROR_IF generator may invalidate the input/output name lists.
        in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self,
            error_name,
            [values_in.name, indices.name, input_tens.name],
            [out_tens.name],
        )

        checks = {
            "op": op,
            "input_shape": values_in.shape,
            "output_shape": out_tens.shape,
            "input_dtype": values_in.dtype,
            "output_dtype": out_tens.dtype,
            "result_tensors": [out_tens],
            "input_list": in_names,
            "output_list": out_names,
            "num_operands": operand_total,
        }
        passed = TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **checks
        )
        if not passed:
            return None

        self.ser.addOperator(op["op"], in_names, out_names)

        meta = self.tensorComplianceMetaData(
            op, values_in.dtype, args_dict, out_tens, error_name
        )
        return TosaTestGen.BuildInfo(out_tens, meta)

    def build_resize(
        self,
        op,
        inputs,
        args_dict,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Add a resize operator to the serializer.

        Args:
            op: Operator definition; its "op" and "operands" entries are read.
            inputs: Exactly four tensors - the input followed by the scale,
                offset and border tensors.
            args_dict: Must provide "mode", "scale", "offset", "border" and
                "output_dtype".
            validator_fcns: ERROR_IF validators handed to the validation step.
                Defaults to None, matching the other build_* methods.
            error_name: ERROR_IF case being generated, or None.
            qinfo: Not used by this builder.

        Returns:
            TosaTestGen.BuildInfo for the result tensor, or None when the
            ERROR_IF validation does not pass.
        """
        assert len(inputs) == 4
        # Local is named in_tens to avoid shadowing the input() builtin.
        in_tens, scale_input, offset_input, border_input = inputs

        mode = args_dict["mode"]
        scale = args_dict["scale"]
        offset = args_dict["offset"]
        border = args_dict["border"]
        output_dtype = args_dict["output_dtype"]

        result_tensor = OutputShaper.resizeOp(
            self.ser,
            self.rng,
            in_tens,
            mode,
            scale,
            offset,
            border,
            in_tens.dtype,
            output_dtype,
            error_name,
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [
            in_tens.name,
            scale_input.name,
            offset_input.name,
            border_input.name,
        ]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            mode=mode,
            scale=scale,
            input_dtype=in_tens.dtype,
            output_dtype=output_dtype,
            input_shape=in_tens.shape,
            output_shape=result_tensor.shape,
            offset=offset,
            border=border,
            input_list=input_list,
            output_list=output_list,
            result_tensors=[result_tensor],
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        # write empty scale/offset/border into ResizeAttribute
        attr.ResizeAttribute([], [], [], mode)
        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, in_tens.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
        """Add an operator with two inputs and two outputs.

        One output tensor is created per input via OutputShaper.unaryOp.
        Only the first output tensor is returned.
        """
        outputs = [
            OutputShaper.unaryOp(self.ser, self.rng, tens, error_name)
            for tens in (val, val2)
        ]
        in_names = [val.name, val2.name]
        out_names = [out.name for out in outputs]
        self.ser.addOperator(op, in_names, out_names)
        return outputs[0]

    def build_const(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Register the single input tensor as a graph output.

        No operator is added; the one tensor in ``inputs`` is marked as an
        output tensor and returned inside a TosaTestGen.BuildInfo together
        with its compliance metadata.
        """
        assert len(inputs) == 1
        const_tens = inputs[0]
        self.ser.addOutputTensor(const_tens)

        meta = self.tensorComplianceMetaData(
            op, const_tens.dtype, args_dict, const_tens, error_name
        )
        return TosaTestGen.BuildInfo(const_tens, meta)

    # Type Conversion
    def build_cast(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a cast operator to the serializer.

        Args:
            op: Operator definition; its "op" and "operands" entries are read.
            inputs: Exactly one tensor - the value to convert.
            args_dict: Must provide "out_type", the requested output dtype.
            validator_fcns: ERROR_IF validators handed to the validation step.
            error_name: ERROR_IF case being generated, or None.
            qinfo: Not used by this builder.

        Returns:
            TosaTestGen.BuildInfo for the result tensor, or None when the
            ERROR_IF validation does not pass.
        """
        assert len(inputs) == 1
        src = inputs[0]

        out_tens = OutputShaper.typeConversionOp(
            self.ser, self.rng, src, args_dict["out_type"], error_name
        )

        p_count, c_count = op["operands"]
        operand_total = p_count + c_count

        # The ERROR_IF generator may invalidate the input/output name lists.
        in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, [src.name], [out_tens.name]
        )

        checks = {
            "op": op,
            "input_shape": src.shape,
            "output_shape": out_tens.shape,
            "input_dtype": src.dtype,
            "output_dtype": out_tens.dtype,
            "result_tensors": [out_tens],
            "input_list": in_names,
            "output_list": out_names,
            "num_operands": operand_total,
        }
        passed = TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **checks
        )
        if not passed:
            return None

        self.ser.addOperator(op["op"], in_names, out_names)

        meta = self.tensorComplianceMetaData(
            op, src.dtype, args_dict, out_tens, error_name
        )
        return TosaTestGen.BuildInfo(out_tens, meta)

    def build_rescale(
        self,
        op,
        inputs,
        args_dict,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Add a rescale operator to the serializer.

        Zero points for the input and output are chosen here based on the
        dtypes (or forced non-zero for the zero-point ERROR_IF cases).  For
        positive scale32 tests the saved placeholder data of the input is
        clamped so that, after subtracting the input zero point, every value
        lies within [-(1 << (shift - 1)), (1 << (shift - 1)) - 1]; the file is
        rewritten only if any value changed.

        Args:
            op: Operator definition; its "op" and "operands" entries are read.
            inputs: Exactly three tensors - the value, multiplier and shift.
            args_dict: Must provide "output_dtype", "scale", "double_round",
                "per_channel", "shift" and "multiplier".
            validator_fcns: ERROR_IF validators handed to the validation step.
            error_name: ERROR_IF case being generated, or None.
            qinfo: Ignored; replaced by the (input_zp, output_zp) pair chosen
                in this method.

        Returns:
            TosaTestGen.BuildInfo for the result tensor, or None when the
            ERROR_IF validation does not pass.
        """
        assert len(inputs) == 3
        val = inputs[0]
        multiplier_val = inputs[1]
        shift_val = inputs[2]
        out_dtype = args_dict["output_dtype"]
        scale32 = args_dict["scale"]
        double_round = args_dict["double_round"]
        per_channel = args_dict["per_channel"]
        shift_arr = args_dict["shift"]
        multiplier_arr = args_dict["multiplier"]

        result_tensor = OutputShaper.typeConversionOp(
            self.ser, self.rng, val, out_dtype, error_name
        )

        # One shift limit per channel (last dimension) or a single shared one
        if per_channel:
            nc = val.shape[-1]
        else:
            nc = 1

        # NOTE(review): the widths are adjusted below but not read again in
        # this method.
        in_type_width = gtu.dtypeWidth(val.dtype)
        out_type_width = gtu.dtypeWidth(out_dtype)

        input_unsigned = False
        output_unsigned = False

        if val.dtype == DType.INT8:
            input_zp = self.randInt(-128, 128)
            in_type_width += 1
        elif val.dtype == DType.UINT8:
            input_zp = self.randInt(0, 256)
            in_type_width += 1
            input_unsigned = True
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.U16InputZeroPointNotValid,
        ]:
            input_zp = self.randInt(-128, 128)
            if input_zp == 0:
                # The error case needs a non-zero zero point
                input_zp = input_zp + self.rng.integers(1, 10)
            in_type_width += 1
        elif val.dtype == DType.UINT16:
            # Must come after ErrorIf.U16InputZeroPointNotValid check
            input_zp = self.rng.choice([0, 32768])
            in_type_width += 1
            input_unsigned = True
        else:
            input_zp = 0

        if out_dtype == DType.INT8:
            output_zp = self.randInt(-128, 128)
            out_type_width += 1
        elif out_dtype == DType.UINT8:
            output_zp = self.randInt(0, 256)
            out_type_width += 1
            output_unsigned = True
        elif error_name in [
            ErrorIf.OutputZeroPointNotZero,
            ErrorIf.U16OutputZeroPointNotValid,
        ]:
            output_zp = self.randInt(-128, 128)
            if output_zp == 0:
                # The error case needs a non-zero zero point
                output_zp = output_zp + self.rng.integers(1, 10)
            out_type_width += 1
        elif out_dtype == DType.UINT16:
            # Must come after ErrorIf.U16OutputZeroPointNotValid check
            output_zp = self.rng.choice([0, 32768])
            out_type_width += 1
            output_unsigned = True
        else:
            output_zp = 0

        min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
        max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

        for i in range(nc):
            min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
            max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

        logger.debug(
            f"build_rescale: multiplier={multiplier_arr} shift={shift_arr} inzp={input_zp} outzp={output_zp}"
        )
        if scale32 and error_name is None:
            # Make sure random values are within apply_scale_32 specification
            # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
            assert val.placeholderFilename
            data_path = os.path.join(
                self.basePath, self.testPath, val.placeholderFilename
            )
            values = np.load(data_path)
            val_adj = np.subtract(values, input_zp, dtype=np.int64)
            val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
            val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
            val_adj = np.add(val_adj, input_zp, dtype=np.int64)
            # Check we can safely convert to the expected dtype.  The
            # comparison must be element-wise: comparing the boolean result of
            # val_adj.all() against the limits would never detect an overflow.
            dtype_info = np.iinfo(values.dtype)
            assert np.all(val_adj >= dtype_info.min) and np.all(
                val_adj <= dtype_info.max
            )

            # Force casting to output datatype
            val_adj = val_adj.astype(values.dtype, casting="unsafe")

            if not np.array_equal(values, val_adj):
                # Values changed so overwrite file with new values
                np.save(data_path, val_adj, allow_pickle=False)

        # Invalidate Input/Output list for error if checks.
        input_list = [val.name, multiplier_val.name, shift_val.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        qinfo = (input_zp, output_zp)
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=val.dtype,
            output_dtype=out_dtype,
            input_shape=val.shape,
            qinfo=qinfo,
            scale32=scale32,
            double_round=double_round,
            input_list=input_list,
            output_list=output_list,
            result_tensors=[result_tensor],
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.RescaleAttribute(
            input_zp,
            output_zp,
            scale32,
            double_round,
            per_channel,
            input_unsigned,
            output_unsigned,
        )

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, val.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def _get_condition_tensor(self, op, cond, error_name):
        """Create the constant condition tensor for a cond_if test.

        The tensor is a rank 0 BOOL constant holding ``cond`` unless an
        ERROR_IF case asks for a wrong dtype (CondIfCondNotMatchingBool) or a
        wrong shape (CondIfCondShapeNotSizeOne).
        """
        cond_type = (
            gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
            if error_name == ErrorIf.CondIfCondNotMatchingBool
            else DType.BOOL
        )

        if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
            # Must be of size 1 (rank 0)
            cond_shape = []
        elif self.rng.choice([1, 2]) == 1:
            cond_shape = [2]
        else:
            cond_shape = [1, 2]

        return self.ser.addConst(cond_shape, cond_type, [cond])

    def build_cond_if_const(
        self,
        op,
        inputs,
        args_dict,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Build a cond_if operator whose then/else blocks hold only constants.

        Args:
            op: Operator definition; its "op" entry is read.
            inputs: Exactly two tensors (then, else).  Only the shape of the
                first one is used; both names are rebound to new constants.
            args_dict: Must provide "condition", the value stored in the
                condition tensor.
            validator_fcns: ERROR_IF validators handed to the validation step.
            error_name: ERROR_IF case being generated, or None.
            qinfo: Not used by this builder.

        Returns:
            TosaTestGen.BuildInfo for the result tensor, or None when the
            ERROR_IF validation does not pass.
        """
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition.  Build Then/Else blocks
        # and fill them with const nodes for the body.
        assert len(inputs) == 2
        then_tens, else_tens = inputs

        cond = args_dict["condition"]

        # Condition tensor
        cond_tens = self._get_condition_tensor(op, cond, error_name)

        # Make then/else tensors
        out_shape = then_tens.shape

        # All constants and the result are created as INT32
        dtype = DType.INT32

        # Create an incorrect output shape for error_if tests
        if error_name in [
            ErrorIf.CondIfOutputListThenGraphMismatch,
            ErrorIf.CondIfOutputListElseGraphMismatch,
        ]:
            incorrect_shape = deepcopy(then_tens.shape)
            for i in range(len(incorrect_shape)):
                # Dimensions of 3 or less only ever grow, so they stay positive
                incorrect_shape[i] += (
                    self.rng.choice([-3, -2, 2, 3])
                    if incorrect_shape[i] > 3
                    else self.rng.choice([1, 2, 4])
                )
            incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

        # Random data in [0, 256) for both branches
        then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tensor = self.ser.addOutput(out_shape, dtype)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

        self.ser.addBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
            then_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            then_tens = self.ser.addConst(out_shape, dtype, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.addBasicBlock(else_block)
        if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
            else_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            else_tens = self.ser.addConst(out_shape, dtype, else_arr)
        self.ser.addOutputTensor(else_tens)

        # Validation runs after the operator and both blocks have been added,
        # as it is given the basic blocks of the current region
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            basicBlocks=self.ser.currRegion.basicBlocks,
            cond=cond_tens,
        ):
            return None

        compliance = self.tensorComplianceMetaData(
            op, dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_cond_if_binary(
        self,
        op,
        inputs,
        args_dict,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Build a cond_if operator with a binary operator in each block.

        The then/else operators are add/sub for FP32, BF16, FP16 and INT32
        inputs, and logical_right_shift/logical_left_shift for INT8 and INT16
        inputs.  Any other dtype trips an assertion.

        Args:
            op: Operator definition; its "op" entry is read.
            inputs: Exactly two tensors, a and b, fed to both blocks.
            args_dict: Must provide "condition", the value stored in the
                condition tensor.
            validator_fcns: ERROR_IF validators handed to the validation step.
            error_name: ERROR_IF case being generated, or None.
            qinfo: Not used by this builder.

        Returns:
            TosaTestGen.BuildInfo for the result tensor, or None when the
            ERROR_IF validation does not pass.  The compliance metadata is
            built for the block operator selected by the condition rather
            than for cond_if itself.
        """
        # For cond_if with a binary op in the then/else blocks, take a and b and
        # alternately add or subtract them based on the condition
        assert len(inputs) == 2
        a, b = inputs

        cond = args_dict["condition"]

        # Condition tensor
        cond_tens = self._get_condition_tensor(op, cond, error_name)

        result_tensor = self.ser.addOutput(a.shape, a.dtype)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        if error_name in [
            ErrorIf.CondIfInputListThenGraphMismatch,
            ErrorIf.CondIfInputListElseGraphMismatch,
            ErrorIf.CondIfOutputListElseGraphMismatch,
            ErrorIf.CondIfOutputListThenGraphMismatch,
        ]:
            # Copy of input a with every dimension perturbed, used as the
            # mismatching block input or output shape below
            # NOTE(review): unlike build_cond_if_const the offset is applied
            # regardless of the dimension size, so a dimension of 3 or less
            # can presumably become zero or negative - confirm this is intended
            incorrect_shape = a.shape.copy()
            for i in range(len(incorrect_shape)):
                incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
            incorrect_block_input = deepcopy(a)
            incorrect_block_input.shape = incorrect_shape

        # Finally, build the op and the two blocks
        self.ser.addOperator(
            op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
        )

        if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
            then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
        elif a.dtype in (DType.INT8, DType.INT16):
            then_op, else_op = (
                self.TOSA_OP_LIST["logical_right_shift"],
                self.TOSA_OP_LIST["logical_left_shift"],
            )
        else:
            assert False, f"No tests for DType: {a.dtype}"

        # Determine the element-wise binary operation that compliance will need to
        # check the results of
        compliance_op = then_op if cond else else_op

        for block, block_op in ((then_block, then_op), (else_block, else_op)):
            self.ser.addBasicBlock(block)
            if (
                error_name == ErrorIf.CondIfInputListThenGraphMismatch
                and block == then_block
            ) or (
                error_name == ErrorIf.CondIfInputListElseGraphMismatch
                and block == else_block
            ):
                # Mismatching block input: first input has the wrong shape
                self.ser.addInputTensor(incorrect_block_input)
                self.ser.addInputTensor(b)
                tens = self.ser.addOutput(a.shape, a.dtype)
            elif (
                error_name == ErrorIf.CondIfOutputListThenGraphMismatch
                and block == then_block
            ) or (
                error_name == ErrorIf.CondIfOutputListElseGraphMismatch
                and block == else_block
            ):
                # Mismatching block output: output has the wrong shape
                self.ser.addInputTensor(a)
                self.ser.addInputTensor(b)
                tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
            else:
                self.ser.addInputTensor(a)
                self.ser.addInputTensor(b)
                tens = self.ser.addOutput(a.shape, a.dtype)
            # The block operator always refers to a and b by name
            self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            a=a,
            b=b,
            basicBlocks=self.ser.currRegion.basicBlocks,
            cond=cond_tens,
        ):
            return None

        compliance = self.tensorComplianceMetaData(
            compliance_op, a.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)

    def build_while_loop(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Build a while_loop operator with its condition and body blocks.

        The loop carries three tensors: an iteration counter, the input a and
        an accumulator.  The condition block outputs ``iter > 0``; the body
        block outputs ``iter - 1``, a and ``a + acc``.

        Args:
            op: Operator definition; its "op" entry is read.
            inputs: Exactly one tensor, a.
            args_dict: Must provide "iterations", the initial counter value.
            validator_fcns: ERROR_IF validators handed to the validation step.
            error_name: ERROR_IF case being generated, or None.
            qinfo: Not used by this builder.

        Returns:
            TosaTestGen.BuildInfo for the accumulator output tensor, or None
            when the ERROR_IF validation does not pass.
        """
        assert len(inputs) == 1
        a = inputs[0]
        iter_val = args_dict["iterations"]

        # Rank 0 INT32 placeholder holding the iteration count
        iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

        cond_block = "COND_BLOCK"
        body_block = "BODY_BLOCK"

        attr = ts.TosaSerializerAttribute()
        attr.WhileLoopAttribute(cond_block, body_block)

        # Accumulator tensor
        # acc = self.ser.addOutput(a.shape, a.dtype)
        # Initial data is an int32 array of zeros, declared with a's dtype
        acc_init_val = np.int32(np.zeros(a.shape))
        acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

        # Intermediate/output tensors for everything going through the loop
        iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        a_out = self.ser.addIntermediate(a.shape, a.dtype)
        if error_name == ErrorIf.InputListOutputListMismatch:
            # Give the accumulator output a different shape to its input
            incorrect_acc = deepcopy(acc)
            for i in range(len(incorrect_acc.shape)):
                incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
            acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
        else:
            acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # While_loop operator
        self.ser.addOperator(
            op["op"],
            [iter.name, a.name, acc.name],
            [iter_out.name, a_out.name, acc_out.name],
            attr,
        )
        self.ser.addOutputTensor(acc_out)

        if error_name in [
            ErrorIf.InputListCondGraphMismatch,
            ErrorIf.InputListBodyGraphInputMismatch,
            ErrorIf.InputListBodyGraphOutputMismatch,
        ]:
            # Wrongly shaped copies of iter and acc for the block mismatches
            incorrect_iter = deepcopy(iter)
            for i in range(len(incorrect_iter.shape)):
                incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
            # A rank 0 shape has nothing to perturb, so add a dimension instead
            if len(incorrect_iter.shape) == 0:
                incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

            incorrect_acc = deepcopy(acc)
            for i in range(len(incorrect_acc.shape)):
                incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

        # COND block (input: iter, output: cond_tens )
        self.ser.addBasicBlock(cond_block)

        if error_name == ErrorIf.InputListCondGraphMismatch:
            self.ser.addInputTensor(incorrect_iter)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(incorrect_acc)
        else:
            self.ser.addInputTensor(iter)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(acc)
        zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

        # The condition output is a rank 0 BOOL unless an ERROR_IF case asks
        # for a wrong dtype or shape
        if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
            cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
        else:
            cond_type = DType.BOOL
        if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
            choice = self.rng.choice([1, 2])
            if choice == 1:
                cond_shape = [3]
            else:
                cond_shape = [1, 2]
        else:
            cond_shape = []
        cond_tens = self.ser.addOutput(cond_shape, cond_type)

        # Loop continues while iter > 0
        self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

        # BODY block (input: a, acc, iter, output: a, acc, iter)
        # Note that local intermediate tensors need to be declared here for the outputs
        self.ser.addBasicBlock(body_block)

        if error_name == ErrorIf.InputListBodyGraphInputMismatch:
            self.ser.addInputTensor(incorrect_iter)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(incorrect_acc)
        else:
            self.ser.addInputTensor(iter)
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(acc)

        one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

        if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
            iter_body_out = self.ser.addIntermediate(
                incorrect_iter.shape, incorrect_iter.dtype
            )
            acc_body_out = self.ser.addIntermediate(
                incorrect_acc.shape, incorrect_acc.dtype
            )
        else:
            iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
            acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # Body: accumulate a into acc and decrement the counter
        self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
        self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
        self.ser.addOutputTensor(iter_body_out)
        self.ser.addOutputTensor(a)
        self.ser.addOutputTensor(acc_body_out)

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            basicBlocks=self.ser.currRegion.basicBlocks,
        ):
            return None

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, acc_out, error_name
        )

        return TosaTestGen.BuildInfo(acc_out, compliance)

    def build_fft2d(
        self,
        op,
        inputs,
        args_dict,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Add an fft2d operator to the serializer.

        Args:
            op: Operator definition; its "op" and "operands" entries are read.
            inputs: Exactly two input tensors.
            args_dict: Must provide "inverse", stored in the FFT attribute.
            validator_fcns: ERROR_IF validators handed to the validation step.
            error_name: ERROR_IF case being generated, or None.
            qinfo: Not used by this builder.

        Returns:
            TosaTestGen.BuildInfo holding all result tensors and one
            compliance entry per result, or None when the ERROR_IF validation
            does not pass.
        """
        assert len(inputs) == 2
        in_a, in_b = inputs
        inverse = args_dict["inverse"]

        results = OutputShaper.fft2dOp(self.ser, self.rng, in_a, in_b, error_name)

        p_count, c_count = op["operands"]
        operand_total = p_count + c_count

        # The ERROR_IF generator may invalidate the input/output name lists.
        in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self,
            error_name,
            [in_a.name, in_b.name],
            [res.name for res in results],
        )

        checks = {
            "op": op,
            "inverse": inverse,
            "input1": in_a,
            "input2": in_b,
            "input_shape": in_a.shape,
            "input_dtype": in_a.dtype,
            "output_shape": [res.shape for res in results],
            "output_dtype": [res.dtype for res in results],
            "result_tensors": results,
            "input_list": in_names,
            "output_list": out_names,
            "num_operands": operand_total,
        }
        passed = TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **checks
        )
        if not passed:
            return None

        # TODO - Test local_bound, for now set local bound attribute to False
        local_bound = False

        attr = ts.TosaSerializerAttribute()
        attr.FFTAttribute(inverse, local_bound)

        self.ser.addOperator(op["op"], in_names, out_names, attr)

        compliance = [
            self.tensorComplianceMetaData(op, in_a.dtype, args_dict, res, error_name)
            for res in results
        ]

        return TosaTestGen.BuildInfo(results, compliance)

    def build_rfft2d(
        self,
        op,
        inputs,
        args_dict,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Add an rfft2d operator to the serializer.

        Args:
            op: Operator definition; its "op" and "operands" entries are read.
            inputs: Exactly one input tensor.
            args_dict: Passed through to the compliance metadata.
            validator_fcns: ERROR_IF validators handed to the validation step.
            error_name: ERROR_IF case being generated, or None.
            qinfo: Not used by this builder.

        Returns:
            TosaTestGen.BuildInfo holding all result tensors and one
            compliance entry per result, or None when the ERROR_IF validation
            does not pass.
        """
        assert len(inputs) == 1
        src = inputs[0]
        results = OutputShaper.rfft2dOp(self.ser, self.rng, src, error_name)

        p_count, c_count = op["operands"]
        operand_total = p_count + c_count

        # The ERROR_IF generator may invalidate the input/output name lists.
        in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, [src.name], [res.name for res in results]
        )

        checks = {
            "op": op,
            "input_shape": src.shape,
            "input_dtype": src.dtype,
            "output_shape": [res.shape for res in results],
            "output_dtype": [res.dtype for res in results],
            "result_tensors": results,
            "input_list": in_names,
            "output_list": out_names,
            "num_operands": operand_total,
        }
        passed = TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **checks
        )
        if not passed:
            return None

        # TODO - Test local_bound, for now set local bound attribute to False
        local_bound = False

        attr = ts.TosaSerializerAttribute()
        attr.RFFTAttribute(local_bound)

        self.ser.addOperator(op["op"], in_names, out_names, attr)

        compliance = [
            self.tensorComplianceMetaData(op, src.dtype, args_dict, res, error_name)
            for res in results
        ]

        return TosaTestGen.BuildInfo(results, compliance)

    def build_shape_op(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a two-input shape operator to the serializer.

        Args:
            op: Operator definition; its "op" and "operands" entries are read.
            inputs: Exactly two input tensors.
            args_dict: Passed through to the compliance metadata.
            validator_fcns: ERROR_IF validators handed to the validation step.
            error_name: ERROR_IF case being generated, or None.
            qinfo: Not used by this builder.

        Returns:
            TosaTestGen.BuildInfo for the result tensor, or None when the
            ERROR_IF validation does not pass.
        """
        assert len(inputs) == 2
        lhs, rhs = inputs

        out_tens = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

        p_count, c_count = op["operands"]
        operand_total = p_count + c_count

        # The ERROR_IF generator may invalidate the input/output name lists.
        in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, [lhs.name, rhs.name], [out_tens.name]
        )

        checks = {
            "op": op,
            "input1": lhs,
            "input2": rhs,
            "input_shape": lhs.shape,
            "input_dtype": lhs.dtype,
            "output_shape": out_tens.shape,
            "output_dtype": out_tens.dtype,
            "result_tensors": [out_tens],
            "input_list": in_names,
            "output_list": out_names,
            "num_operands": operand_total,
        }
        passed = TosaErrorValidator.evValidateErrorIfs(
            self.ser, validator_fcns, error_name, **checks
        )
        if not passed:
            return None

        self.ser.addOperator(op["op"], in_names, out_names)

        meta = self.tensorComplianceMetaData(
            op, lhs.dtype, args_dict, out_tens, error_name
        )
        return TosaTestGen.BuildInfo(out_tens, meta)

    def create_filter_lists(
        self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
    ):
        """Combine the requested filters with what the operator supports.

        op supplies the allowed "rank" (min, max) pair and the "types" list.
        For "positive" tests the cleaned filters are returned as they are.
        For "negative" tests any requirement in the validator's "param_reqs"
        replaces the matching filter, and the shape filter is otherwise cut
        down to its first two entries.

        Returns a dictionary with "shapeFilter", "rankFilter" and
        "dtypeFilter" entries, or None when testType is not recognised or a
        negative test has no validator."""
        if not shapeFilter:
            shapeFilter = [None]

        # Ranks: requested ranks are restricted to what the operator allows
        rmin, rmax = op["rank"]
        if rankFilter is not None:
            cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
        elif shapeFilter[0] is None:
            # Nothing requested at all - use ranks 1-4 inclusive to keep test
            # sizes reasonably small, unless the operator range is no longer.
            smallRanks = range(1, 5)
            opRanks = range(rmin, rmax + 1)
            if len(opRanks) > len(smallRanks):
                cleanRankFilter = smallRanks
            else:
                cleanRankFilter = opRanks
        else:
            cleanRankFilter = range(rmin, rmax + 1)

        # Dtypes: keep operator dtypes that were requested, dtype lists are
        # matched on their first entry
        opDtypes = op["types"]
        if dtypeFilter is None:
            cleanDtypeFilter = opDtypes
        else:
            cleanDtypeFilter = [
                dtype
                for dtype in opDtypes
                if dtype in dtypeFilter
                or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
            ]

        if testType == "positive":
            return {
                "shapeFilter": shapeFilter,
                "rankFilter": cleanRankFilter,
                "dtypeFilter": cleanDtypeFilter,
            }

        if testType != "negative" or validator is None:
            return None

        # Negative test - the validator may dictate the parameters to use
        requirements = validator(check=False, op=op)["param_reqs"]
        requiredRank = requirements["rank"]
        requiredDtype = requirements["dtype"]
        requiredShape = requirements["shape"]

        if requiredShape is None:
            # Reduce number of shapes to keep test numbers small
            requiredShape = shapeFilter[:2]

        return {
            "shapeFilter": requiredShape,
            "rankFilter": cleanRankFilter if requiredRank is None else requiredRank,
            "dtypeFilter": cleanDtypeFilter if requiredDtype is None else requiredDtype,
        }

    def genOpTestList(
        self,
        opName,
        shapeFilter=None,
        rankFilter=None,
        dtypeFilter=None,
        testType="positive",
    ):
        """Generate the list of tests to create for an operator.

        opName must be a key of TOSA_OP_LIST, otherwise an Exception is
        raised.  shapeFilter, rankFilter and dtypeFilter are passed through
        create_filter_lists to decide which ranks, dtypes and shapes are
        iterated over; shapeFilter left as None means no specific shape.
        testType is "positive" or "negative".  Negative tests are produced
        once per entry of the operator's "error_if_validators".

        Returns a list of tuples of:
        (opName, testStr, dtype, error_name, shapeList, args)
        or an empty list when create_filter_lists supplies no filters."""
        # None stands in for the previous [None] default so that no mutable
        # object is shared between calls.
        if shapeFilter is None:
            shapeFilter = [None]

        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError:
            raise Exception("Cannot find op with name {}".format(opName)) from None

        # Initialize a new random number generator
        self.rng = np.random.default_rng(self.random_seed)

        _, tgen_fcn, _, agen_fcn = op["build_fcn"]

        isNegative = testType == "negative"
        if isNegative and "error_if_validators" in op:
            error_if_validators = op["error_if_validators"]
        else:
            error_if_validators = [None]

        testList = []
        for validator in error_if_validators:
            if validator is not None:
                error_name = validator(check=False, op=op)["error_name"]
            else:
                error_name = None

            filterDict = self.create_filter_lists(
                op, shapeFilter, rankFilter, dtypeFilter, testType, validator
            )
            if filterDict is None:
                return []
            cleanRankFilter = filterDict["rankFilter"]
            cleanDtypeFilter = filterDict["dtypeFilter"]
            cleanShapeFilter = filterDict["shapeFilter"]
            logger.debug(
                f"genOpTestList: Error={error_name}, Filters S={cleanShapeFilter}, R={cleanRankFilter}, T={cleanDtypeFilter}"
            )

            for r in cleanRankFilter:
                for t in cleanDtypeFilter:
                    for shape in cleanShapeFilter:
                        # Filter out by rank
                        if shape is not None and len(shape) != r:
                            continue
                        self.setTargetShape(shape)
                        shapeList = tgen_fcn(self, op, r, error_name)

                        shapeStr = self.shapeStr(shapeList[0])
                        typeStr = self.typeStr(t)

                        # Part of the test name shared by every argument set
                        if isNegative:
                            baseStr = (
                                f"{opName}_ERRORIF_{error_name}_{shapeStr}_{typeStr}"
                            )
                        else:
                            baseStr = f"{opName}_{shapeStr}_{typeStr}"

                        # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
                        if agen_fcn:
                            argList = agen_fcn(self, opName, shapeList, t, error_name)
                        else:
                            argList = [("", [])]

                        for argStr, args in argList:
                            testStr = f"{baseStr}_{argStr}" if argStr else baseStr
                            testList.append(
                                (opName, testStr, t, error_name, shapeList, args)
                            )

        if testType == "positive" and "invalid_test_validators" in op:
            # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement.
            # Checking stops at the first validator that rejects a test.
            invalid_test_validators = op["invalid_test_validators"]
            testList = [
                test
                for test in testList
                if not any(
                    validator_fcn(
                        opName=test[0],
                        input_dtype=test[2],
                        shapeList=test[4],
                        args=test[5],
                    )
                    for validator_fcn in invalid_test_validators
                )
            ]

        return testList

    def serializeTest(
        self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
    ):
        """Build a single test and serialize it if it turned out valid.

        opName must be a key of TOSA_OP_LIST, otherwise an Exception is
        raised.  dtype_or_dtypeList is either a list used as is, or a single
        dtype that is repeated once per operand (once per shape for the
        concat operators).  argsDict must be a dictionary containing a
        "dg_type" entry.

        Nothing is returned; an error is logged when the build function does
        not produce a result."""
        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError:
            raise Exception("Cannot find op with name {}".format(opName))

        logger.info(f"Creating {testStr}")

        # Create a serializer
        self.createSerializer(opName, testStr)

        build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
        error_if_validators = op.get("error_if_validators")

        placeholder_count, const_count = op["operands"]
        num_operands = placeholder_count + const_count

        # The concat operators take their operand count from the shape list
        is_concat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

        if isinstance(dtype_or_dtypeList, list):
            dtypeList = dtype_or_dtypeList
        else:
            copies = len(shapeList) if is_concat else num_operands
            dtypeList = [dtype_or_dtypeList] * copies

        if not is_concat:
            assert (
                len(shapeList) == num_operands
            ), "shapeList length {} must match number of operands {}".format(
                len(shapeList), num_operands
            )
            assert (
                len(dtypeList) == num_operands
            ), "dtypeList length {} must match number of operands {}".format(
                len(dtypeList), num_operands
            )

        # Quantization info comes from the optional "qgen" function
        qgen = op.get("qgen")
        qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

        # Only the new interface with args info in a dictionary is supported
        assert isinstance(
            argsDict, dict
        ), f"{opName} is not using new tvg/build_fcn interface"
        assert "dg_type" in argsDict

        # Build the random tensor operands
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)

        # Extra meta data for the desc.json
        tensMeta = {}
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict

        # Build the test
        result = build_fcn(
            self,
            op,
            tvgInfo.tensorList,
            argsDict,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )

        if not result:
            # The test is not valid
            logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
            return

        # The test is valid, serialize it with the compliance meta data (if any)
        if isinstance(result, TosaTestGen.BuildInfo):
            compliance = result.getComplianceInfo()
            if compliance:
                tensMeta["compliance"] = compliance
        self.serialize("test", tensMeta)

    def createDynamicOpLists(self):
        """Expand the convolution templates into one operator per kernel size.

        Each kernel gets a copy of the matching "<name>_TEMPLATE" entry of
        TOSA_OP_LIST, stored as "<name>_<kernel dims joined by x>" with its
        "filter", "template" and "real_name" fields filled in.  Every entry
        flagged as a template is deleted afterwards.  Does nothing when
        "conv2d_TEMPLATE" is no longer in the list."""
        opList = self.TOSA_OP_LIST

        if "conv2d_TEMPLATE" not in opList:
            # Already created these lists (can occur when class is initialized more than once)
            return

        # Kernel sizes to create operators for
        if self.args.level8k:
            bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
            kernels2d = [[1, bigK], [bigK, 2]]
            kernels3d = [[1, bigK, 1], [2, 2, bigK]]
        else:
            kernels2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
            kernels3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

        def addKernelOp(realName, kernel):
            # Register a copy of the template specialised for this kernel
            entry = opList[realName + "_TEMPLATE"].copy()
            entry["filter"] = kernel
            entry["template"] = False
            entry["real_name"] = realName
            kernelStr = "x".join(str(dim) for dim in kernel)
            opList["{}_{}".format(realName, kernelStr)] = entry

        for kernel in kernels2d:
            for realName in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
                addKernelOp(realName, kernel)

        for kernel in kernels3d:
            addKernelOp("conv3d", kernel)

        # Collect the templates first and delete them afterwards, so that the
        # dictionary is never changed while it is being iterated.
        templateNames = [
            name
            for name in opList
            if "template" in opList[name] and opList[name]["template"]
        ]
        for name in templateNames:
            del opList[name]

    def initOpListDefaults(self):
        """Lint the entries of TOSA_OP_LIST and fill in optional defaults.

        Each entry needs an "operands" pair, a four element "build_fcn"
        tuple, a "types" field and an "op" field; an Exception naming the
        operator is raised for the first one found missing or malformed.
        Entries without a "rank" are given DEFAULT_RANK_RANGE."""
        # Fields that only need to be present, with the complaint to raise
        presence_checks = (
            ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
            ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
        )

        for op_name, entry in self.TOSA_OP_LIST.items():
            # Tuple fields are checked by unpacking them to the expected size
            try:
                _placeholders, _consts = entry["operands"]
            except (KeyError, ValueError, TypeError):
                raise Exception(
                    "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(
                        op_name
                    )
                )

            try:
                _build, _tgen, _tvgen, _agen = entry["build_fcn"]
            except (KeyError, ValueError, TypeError):
                raise Exception(
                    "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                        op_name
                    )
                )

            for field, complaint in presence_checks:
                try:
                    entry[field]
                except KeyError:
                    raise Exception(complaint.format(op_name))

            # Put in default rank range, if missing
            if "rank" not in entry:
                entry["rank"] = self.DEFAULT_RANK_RANGE

    # Tensor operator list
    #  'op': Op value of the operator
    #  'operands': tuple of (placeholder, const) operand counts
    #  'rank': optional, restricts rank to tuple inclusive of (min, max),
    #    if not specified, defaults to DEFAULT_RANK_RANGE
    #  'build_fcn': tuple of (build function, TensorGen function,
    #    TensorValuesGen function, ArgGen function)
    #  'types': array of datatypes to be tested
    #  'qgen': optional, function called by serializeTest to create the qinfo
    #  'error_if_validators': optional, validators that genOpTestList creates
    #    negative tests for
    #  'invalid_test_validators': optional, validators that genOpTestList uses
    #    to remove positive tests
    #  'template': optional, entries with this set are removed by
    #    createDynamicOpLists
    TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.INT32,
    ]  # floating-types and INT32
    # The TYPE_INT_FP entries plus BOOL
    TYPE_FIB = [
        DType.FP16,
        DType.BF16,
        DType.FP32,
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.BOOL,
    ]
    TYPE_FI16 = [DType.FP32, DType.INT16]

    # The TYPE_INT_FP entries without INT32
    TYPE_NARROW_INT_FP = [
        DType.INT8,
        DType.INT16,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]

    # List of [Input Type 1, Input Type 2, Accumulator Type]
    TYPE_CONV = [
        [DType.INT8, DType.INT4, DType.INT32],
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        [DType.FP16, DType.FP16, DType.FP16],
        [DType.FP16, DType.FP16, DType.FP32],
        [DType.BF16, DType.BF16, DType.FP32],
        [DType.FP32, DType.FP32, DType.FP32],
        [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
        [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
    ]

    # Inclusive (min, max) rank given by initOpListDefaults to any operator
    # without a 'rank' entry
    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)

    TOSA_OP_LIST = {
        # Tensor operators
        "argmax": {
            "op": Op.ARGMAX,
            "operands": (1, 0),
            "rank": (1, 6),
            "build_fcn": (
                build_argmax,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agAxis,
            ),
            "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
            "error_if_validators": (
                TosaErrorValidator.evAxisSmallerZero,
                TosaErrorValidator.evAxisLargerRank,
                TosaErrorValidator.evArgmaxOutputRankMismatch,
                TosaErrorValidator.evArgmaxOutputShapeMismatch,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "avg_pool2d": {
            "op": Op.AVG_POOL2D,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (
                build_pool2d,
                TosaTensorGen.tgNHWC,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agPooling,
            ),
            "qgen": TosaQuantGen.qgUnary,
            "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
            "error_if_validators": (
                TosaErrorValidator.evKernelSmallerOne,
                TosaErrorValidator.evStrideSmallerOne,
                TosaErrorValidator.evPadSmallerZero,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evInputZeroPointNotZero,
                TosaErrorValidator.evOutputZeroPointNotZero,
                TosaErrorValidator.evPadLargerEqualKernel,
                TosaErrorValidator.evPoolingOutputShapeMismatch,
                TosaErrorValidator.evPoolingOutputShapeNonInteger,
                TosaErrorValidator.evWrongAccumulatorType,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.DOT_PRODUCT,),
            },
        },
        # Templated operator.  Filled in by createDynamicOpLists
        "conv2d_TEMPLATE": {
            "op": Op.CONV2D,
            "operands": (1, 2),
            "rank": (4, 4),
            "build_fcn": (
                build_conv2d,
                TosaTensorGen.tgConv2D,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agConv,
            ),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evInputZeroPointNotZero,
                TosaErrorValidator.evWeightZeroPointNotZero,
                TosaErrorValidator.evPadSmallerZero,
                TosaErrorValidator.evStrideSmallerOne,
                TosaErrorValidator.evDilationSmallerOne,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evConvOutputShapeMismatch,
                TosaErrorValidator.evConvOutputShapeNonInteger,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.DOT_PRODUCT,),
            },
            "template": True,
        },
        # Templated operator.  Filled in by createDynamicOpLists
        "conv3d_TEMPLATE": {
            "op": Op.CONV3D,
            "operands": (1, 2),
            "rank": (5, 5),
            "build_fcn": (
                build_conv3d,
                TosaTensorGen.tgConv3D,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agConv,
            ),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evInputZeroPointNotZero,
                TosaErrorValidator.evWeightZeroPointNotZero,
                TosaErrorValidator.evPadSmallerZero,
                TosaErrorValidator.evStrideSmallerOne,
                TosaErrorValidator.evDilationSmallerOne,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evConvOutputShapeMismatch,
                TosaErrorValidator.evConvOutputShapeNonInteger,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.DOT_PRODUCT,),
            },
            "template": True,
        },
        # Templated operator.  Filled in by createDynamicOpLists
        "depthwise_conv2d_TEMPLATE": {
            "op": Op.DEPTHWISE_CONV2D,
            "operands": (1, 2),
            "filter": [1, 1],
            "rank": (4, 4),
            "build_fcn": (
                build_depthwise_conv2d,
                TosaTensorGen.tgDepthwiseConv2D,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agConv,
            ),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evInputZeroPointNotZero,
                TosaErrorValidator.evWeightZeroPointNotZero,
                TosaErrorValidator.evPadSmallerZero,
                TosaErrorValidator.evStrideSmallerOne,
                TosaErrorValidator.evDilationSmallerOne,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evConvOutputShapeMismatch,
                TosaErrorValidator.evConvOutputShapeNonInteger,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.DOT_PRODUCT,),
            },
            "template": True,
        },
        "fully_connected": {
            "op": Op.FULLY_CONNECTED,
            "operands": (1, 2),
            "rank": (2, 2),
            "build_fcn": (
                build_fully_connected,
                TosaTensorGen.tgFullyConnected,
                TosaTensorValuesGen.tvgFullyConnected,
                TosaArgGen.agFullyConnected,
            ),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
            "error_if_validators": (
                TosaErrorValidator.evInputZeroPointNotZero,
                TosaErrorValidator.evWeightZeroPointNotZero,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.DOT_PRODUCT,),
            },
        },
        "matmul": {
            "op": Op.MATMUL,
            "operands": (2, 0),
            "rank": (3, 3),
            "build_fcn": (
                build_matmul,
                TosaTensorGen.tgMatmul,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agMatMul,
            ),
            "qgen": TosaQuantGen.qgMatmul,
            "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
            "error_if_validators": (
                TosaErrorValidator.evInputZeroPointNotZero,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.DOT_PRODUCT,),
            },
        },
        "max_pool2d": {
            "op": Op.MAX_POOL2D,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (
                build_pool2d,
                TosaTensorGen.tgNHWC,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agPooling,
            ),
            "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
            "error_if_validators": (
                TosaErrorValidator.evKernelSmallerOne,
                TosaErrorValidator.evStrideSmallerOne,
                TosaErrorValidator.evPadSmallerZero,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evPadLargerEqualKernel,
                TosaErrorValidator.evPoolingOutputShapeMismatch,
                TosaErrorValidator.evPoolingOutputShapeNonInteger,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        # Templated operator.  Filled in by createDynamicOpLists
        "transpose_conv2d_TEMPLATE": {
            "op": Op.TRANSPOSE_CONV2D,
            "operands": (1, 2),
            "rank": (4, 4),
            "build_fcn": (
                build_transpose_conv2d,
                TosaTensorGen.tgTransposeConv2D,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agTransposeConv2D,
            ),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
            "invalid_test_validators": (
                TosaInvalidValidator.ivHeightWidthInvalid,
                TosaInvalidValidator.ivNonPositiveOutputShape,
            ),
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evInputZeroPointNotZero,
                TosaErrorValidator.evWeightZeroPointNotZero,
                TosaErrorValidator.evPadLargerEqualKernel,
                TosaErrorValidator.evStrideSmallerOne,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evConvOutputShapeMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.DOT_PRODUCT,),
            },
            "template": True,
        },
        # Activation functions
        "clamp": {
            "op": Op.CLAMP,
            "operands": (1, 0),
            "build_fcn": (
                build_clamp,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_NARROW_INT_FP,
            "error_if_validators": (
                TosaErrorValidator.evMaxSmallerMin,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "sigmoid": {
            "op": Op.SIGMOID,
            "operands": (1, 0),
            "build_fcn": (
                build_activation,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FP,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "tanh": {
            "op": Op.TANH,
            "operands": (1, 0),
            "build_fcn": (
                build_activation,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FP,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
            "compliance": {
                "abs_error_lower_bound": 0.5,
            },
        },
        "erf": {
            "op": Op.ERF,
            "operands": (1, 0),
            "build_fcn": (
                build_activation,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FP,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
            "compliance": {"ulp": 5},
        },
        # Elementwise Binary Operators
        "add": {
            "op": Op.ADD,
            "operands": (2, 0),
            "build_fcn": (
                build_binary_broadcast,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgAddSub,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FI32,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
            "compliance": {"ulp": 0.5},
        },
        "arithmetic_right_shift": {
            "op": Op.ARITHMETIC_RIGHT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (
                build_arithmetic_right_shift,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgArithmeticRightShift,
                TosaArgGen.agArithmeticRightShift,
            ),
            "types": TYPE_INT,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
        },
        "bitwise_and": {
            "op": Op.BITWISE_AND,
            "operands": (2, 0),
            "build_fcn": (
                build_binary_broadcast,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_INT,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
        },
        "bitwise_or": {
            "op": Op.BITWISE_OR,
            "operands": (2, 0),
            "build_fcn": (
                build_binary_broadcast,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_INT,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
        },
        "bitwise_xor": {
            "op": Op.BITWISE_XOR,
            "operands": (2, 0),
            "build_fcn": (
                build_binary_broadcast,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_INT,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
        },
        "intdiv": {
            "op": Op.INTDIV,
            "operands": (2, 0),
            "build_fcn": (
                build_binary_broadcast,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgIntDiv,
                TosaArgGen.agNone,
            ),
            "types": [DType.INT32],
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
        },
        "logical_and": {
            "op": Op.LOGICAL_AND,
            "operands": (2, 0),
            "build_fcn": (
                build_binary_broadcast,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_BOOL,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
        },
        "logical_left_shift": {
            "op": Op.LOGICAL_LEFT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (
                build_binary_broadcast,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgLogicalShift,
                TosaArgGen.agNone,
            ),
            "types": TYPE_INT,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
        },
        "logical_right_shift": {
            "op": Op.LOGICAL_RIGHT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (
                build_binary_broadcast,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgLogicalShift,
                TosaArgGen.agNone,
            ),
            "types": TYPE_INT,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
        },
        "logical_or": {
            "op": Op.LOGICAL_OR,
            "operands": (2, 0),
            "build_fcn": (
                build_binary_broadcast,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_BOOL,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
        },
        "logical_xor": {
            "op": Op.LOGICAL_XOR,
            "operands": (2, 0),
            "build_fcn": (
                build_binary_broadcast,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_BOOL,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
        },
        "maximum": {
            "op": Op.MAXIMUM,
            "operands": (2, 0),
            "build_fcn": (
                build_binary_broadcast,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FI32,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "minimum": {
            "op": Op.MINIMUM,
            "operands": (2, 0),
            "build_fcn": (
                build_binary_broadcast,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FI32,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "mul": {
            "op": Op.MUL,
            "operands": (3, 0),
            "build_fcn": (
                build_mul,
                TosaTensorGen.tgMul,
                TosaTensorValuesGen.tvgMul,
                TosaArgGen.agMul,
            ),
            "types": TYPE_INT_FP,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
            "compliance": {"ulp": 0.5},
        },
        "pow": {
            "op": Op.POW,
            "operands": (2, 0),
            "build_fcn": (
                build_binary_broadcast,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgPow,
                TosaArgGen.agPow,
            ),
            "types": TYPE_FP,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "sub": {
            "op": Op.SUB,
            "operands": (2, 0),
            "build_fcn": (
                build_binary_broadcast,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgAddSub,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FI32,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
            "compliance": {"ulp": 0.5},
        },
        "table": {
            "op": Op.TABLE,
            # Use the automatic generation functions to create the input array
            # but create the table tensor in the build function, as it may be
            # a different type from the input
            "operands": (1, 0),
            "build_fcn": (
                build_table,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agTable,
            ),
            "types": [DType.INT8, DType.INT16],
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
        },
        # Elementwise Unary operators
        "abs": {
            "op": Op.ABS,
            "operands": (1, 0),
            "build_fcn": (
                build_unary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FI32,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.FULL_RANGE,),
            },
        },
        "bitwise_not": {
            "op": Op.BITWISE_NOT,
            "operands": (1, 0),
            "build_fcn": (
                build_unary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_INT,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
        },
        "ceil": {
            "op": Op.CEIL,
            "operands": (1, 0),
            "build_fcn": (
                build_unary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FP,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.FULL_RANGE,),
            },
            "compliance": {"ulp": 0.5},
        },
        "clz": {
            "op": Op.CLZ,
            "operands": (1, 0),
            "build_fcn": (
                build_unary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": [DType.INT32],
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
        },
        "cos": {
            "op": Op.COS,
            "operands": (1, 0),
            "build_fcn": (
                build_unary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FP,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
            "compliance": {"abs_error_normal_divisor": 2},
        },
        "exp": {
            "op": Op.EXP,
            "operands": (1, 0),
            "build_fcn": (
                build_unary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgExp,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FP,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.FULL_RANGE,),
            },
        },
        "floor": {
            "op": Op.FLOOR,
            "operands": (1, 0),
            "build_fcn": (
                build_unary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FP,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.FULL_RANGE,),
            },
            "compliance": {"ulp": 0.5},
        },
        "log": {
            "op": Op.LOG,
            "operands": (1, 0),
            "build_fcn": (
                build_unary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLogRsqrt,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FP,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.FULL_RANGE,),
            },
            "compliance": {"ulp": 5},
        },
        "logical_not": {
            "op": Op.LOGICAL_NOT,
            "operands": (1, 0),
            "build_fcn": (
                build_unary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_BOOL,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
        },
        "negate": {
            "op": Op.NEGATE,
            "operands": (1, 0),
            "build_fcn": (
                build_unary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgNegate,
                TosaArgGen.agNone,
            ),
            "qgen": TosaQuantGen.qgUnary,
            "types": TYPE_INT_FP,
            "error_if_validators": (
                TosaErrorValidator.evInputZeroPointNotZero,
                TosaErrorValidator.evOutputZeroPointNotZero,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.FULL_RANGE,),
            },
        },
        "reciprocal": {
            "op": Op.RECIPROCAL,
            "operands": (1, 0),
            "build_fcn": (
                build_unary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FP,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.FULL_RANGE,),
            },
            "compliance": {"ulp": 1.0},
        },
        "rsqrt": {
            "op": Op.RSQRT,
            "operands": (1, 0),
            "build_fcn": (
                build_unary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLogRsqrt,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FP,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.FULL_RANGE,),
            },
            "compliance": {"ulp": 2},
        },
        "sin": {
            "op": Op.SIN,
            "operands": (1, 0),
            "build_fcn": (
                build_unary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FP,
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
            "compliance": {"abs_error_normal_divisor": 2},
        },
        # Elementwise Ternary operators
        "select": {
            "op": Op.SELECT,
            "operands": (3, 0),
            "build_fcn": (
                build_select,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgSelect,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FIB,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        # Comparison operators
        "equal": {
            "op": Op.EQUAL,
            "operands": (2, 0),
            "build_fcn": (
                build_comparison,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgEqual,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FI32,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "greater_equal": {
            "op": Op.GREATER_EQUAL,
            "operands": (2, 0),
            "build_fcn": (
                build_comparison,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FI32,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "greater": {
            "op": Op.GREATER,
            "operands": (2, 0),
            "build_fcn": (
                build_comparison,
                TosaTensorGen.tgBroadcastFuzz,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FI32,
            "error_if_validators": (
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evDimensionMismatch,
                TosaErrorValidator.evBroadcastShapesMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        # Reduction operators
        "reduce_all": {
            "op": Op.REDUCE_ALL,
            "operands": (1, 0),
            "build_fcn": (
                build_reduce,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agAxis,
            ),
            "types": TYPE_BOOL,
            "error_if_validators": (
                TosaErrorValidator.evAxisLargerRank,
                TosaErrorValidator.evAxisSmallerZero,
                TosaErrorValidator.evShapeOfAxisNotOne,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
        },
        "reduce_any": {
            "op": Op.REDUCE_ANY,
            "operands": (1, 0),
            "build_fcn": (
                build_reduce,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agAxis,
            ),
            "types": TYPE_BOOL,
            "error_if_validators": (
                TosaErrorValidator.evAxisLargerRank,
                TosaErrorValidator.evAxisSmallerZero,
                TosaErrorValidator.evShapeOfAxisNotOne,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
        },
        "reduce_max": {
            "op": Op.REDUCE_MAX,
            "operands": (1, 0),
            "build_fcn": (
                build_reduce,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agAxis,
            ),
            "types": TYPE_INT_FP,
            "error_if_validators": (
                TosaErrorValidator.evAxisLargerRank,
                TosaErrorValidator.evAxisSmallerZero,
                TosaErrorValidator.evShapeOfAxisNotOne,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "reduce_min": {
            "op": Op.REDUCE_MIN,
            "operands": (1, 0),
            "build_fcn": (
                build_reduce,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agAxis,
            ),
            "types": TYPE_INT_FP,
            "error_if_validators": (
                TosaErrorValidator.evAxisLargerRank,
                TosaErrorValidator.evAxisSmallerZero,
                TosaErrorValidator.evShapeOfAxisNotOne,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "reduce_product": {
            "op": Op.REDUCE_PRODUCT,
            "operands": (1, 0),
            "build_fcn": (
                build_reduce,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgReduceProduct,
                TosaArgGen.agAxis,
            ),
            "types": TYPE_FP,
            "error_if_validators": (
                TosaErrorValidator.evAxisLargerRank,
                TosaErrorValidator.evAxisSmallerZero,
                TosaErrorValidator.evShapeOfAxisNotOne,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "reduce_sum": {
            "op": Op.REDUCE_SUM,
            "operands": (1, 0),
            "build_fcn": (
                build_reduce,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgReduceSum,
                TosaArgGen.agAxis,
            ),
            "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
            "error_if_validators": (
                TosaErrorValidator.evAxisLargerRank,
                TosaErrorValidator.evAxisSmallerZero,
                TosaErrorValidator.evShapeOfAxisNotOne,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.DOT_PRODUCT,),
            },
        },
        # Data layout operators
        "concat": {
            "op": Op.CONCAT,
            "operands": (2, 0),
            "build_fcn": (
                build_concat,
                TosaTensorGen.tgConcat,
                TosaTensorValuesGen.tvgConcat,
                TosaArgGen.agAxis,
            ),
            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
            "error_if_validators": (
                TosaErrorValidator.evAxisLargerRank,
                TosaErrorValidator.evAxisSmallerZero,
                TosaErrorValidator.evConcatInputRankMismatch,
                TosaErrorValidator.evConcatShapeSumMismatch,
                TosaErrorValidator.evConcatInputDimMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "pad": {
            "op": Op.PAD,
            "operands": (2, 0),
            "build_fcn": (
                build_pad,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgPad,
                TosaArgGen.agPad,
            ),
            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evPadSmallerZero,
                TosaErrorValidator.evPadOutputShapeMismatch,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongRank,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "dim": {
            "op": Op.DIM,
            "operands": (1, 0),
            "build_fcn": (
                build_dim,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agAxis,
            ),
            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
            "error_if_validators": (
                TosaErrorValidator.evAxisLargerRank,
                TosaErrorValidator.evAxisSmallerZero,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evWrongRank,
            ),
        },
        "reshape": {
            "op": Op.RESHAPE,
            "operands": (2, 0),
            "build_fcn": (
                build_reshape,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgReshape,
                TosaArgGen.agReshape,
            ),
            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
            "error_if_validators": (
                TosaErrorValidator.evTensorSizeInputOutputMismatch,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "reverse": {
            "op": Op.REVERSE,
            "operands": (1, 0),
            "build_fcn": (
                build_reverse,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agAxis,
            ),
            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
            "error_if_validators": (
                TosaErrorValidator.evAxisSmallerZero,
                TosaErrorValidator.evAxisLargerRank,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "slice": {
            "op": Op.SLICE,
            "operands": (3, 0),
            "rank": (1, 6),
            "build_fcn": (
                build_slice,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgSlice,
                TosaArgGen.agSlice,
            ),
            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
            "error_if_validators": (
                # TODO Turn off these error categories for now as the reference
                # model cannot allocate memory space for empty tensor. We probably
                # can report an accurate error message at the right place during
                # execution.
                # TosaErrorValidator.evStartSmallerZero,
                # TosaErrorValidator.evSizeSmallerEqualZero,
                TosaErrorValidator.evStartSizeOutsideBounds,
                TosaErrorValidator.evSizeOutputShapeMismatch,
                TosaErrorValidator.evInputSizeStartLengthMismatch,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evRankMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "tile": {
            "op": Op.TILE,
            "operands": (2, 0),
            "rank": (1, 6),
            "build_fcn": (
                build_tile,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgTile,
                TosaArgGen.agTile,
            ),
            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evWrongRank,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "transpose": {
            "op": Op.TRANSPOSE,
            "operands": (1, 0),
            "rank": (1, 6),
            "build_fcn": (
                build_transpose,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agTranspose,
            ),
            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
            "error_if_validators": (
                TosaErrorValidator.evIndexOutsideBounds,
                TosaErrorValidator.evIndexUsedTwice,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evRankMismatch,
                TosaErrorValidator.evTensorSizeInputOutputMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        # Data nodes
        "const": {
            "op": Op.CONST,
            "operands": (0, 1),
            "build_fcn": (
                build_const,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "identity": {
            "op": Op.IDENTITY,
            "operands": (1, 0),
            "build_fcn": (
                build_unary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": TYPE_FIB + [DType.INT4, DType.INT48],
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        # Scatter/Gather
        "gather": {
            "op": Op.GATHER,
            "operands": (2, 0),
            "rank": (3, 3),
            "build_fcn": (
                build_gather,
                TosaTensorGen.tgGather,
                TosaTensorValuesGen.tvgGather,
                TosaArgGen.agNone,
            ),
            "types": (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
                DType.FP8E4M3,
                DType.FP8E5M2,
            ),
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evWrongRank,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        "scatter": {
            "op": Op.SCATTER,
            "operands": (3, 0),
            "rank": (3, 3),
            "build_fcn": (
                build_scatter,
                TosaTensorGen.tgScatter,
                TosaTensorValuesGen.tvgScatter,
                TosaArgGen.agNone,
            ),
            "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evWrongRank,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
        },
        # Image operations
        "resize": {
            "op": Op.RESIZE,
            "operands": (4, 0),
            "rank": (4, 4),
            "build_fcn": (
                build_resize,
                TosaTensorGen.tgNHWC,
                TosaTensorValuesGen.tvgResize,
                TosaArgGen.agResize,
            ),
            "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
            "invalid_test_validators": (
                TosaInvalidValidator.ivWrongDataTypeOrModeResize,
            ),
            "error_if_validators": (
                TosaErrorValidator.evMaxDimExceeded,
                TosaErrorValidator.evScaleSmallerEqualZero,
                TosaErrorValidator.evScaleNLargerMax,
                TosaErrorValidator.evScaleDLargerMax,
                TosaErrorValidator.evOffsetSmallerMin,
                TosaErrorValidator.evOffsetLargerEqualMax,
                TosaErrorValidator.evBorderSmallerMin,
                TosaErrorValidator.evBorderLargerEqualMax,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evBatchMismatch,
                TosaErrorValidator.evChannelMismatch,
                TosaErrorValidator.evResizeOutputShapeMismatch,
                TosaErrorValidator.evResizeOutputShapeNonInteger,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
            "compliance": {"relative": 0.006},
        },
        # Type conversion
        "cast": {
            "op": Op.CAST,
            "operands": (1, 0),
            "build_fcn": (
                build_cast,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgCast,
                TosaArgGen.agCast,
            ),
            "types": (
                DType.FP16,
                DType.BF16,
                DType.FP32,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.BOOL,
                DType.FP8E4M3,
                DType.FP8E5M2,
            ),
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
            },
            "compliance": {"ulp": 0.5},
        },
        "rescale": {
            "op": Op.RESCALE,
            "operands": (3, 0),
            "build_fcn": (
                build_rescale,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgRescale,
                TosaArgGen.agRescale,
            ),
            "types": [
                DType.UINT8,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.UINT16,
            ],
            "error_if_validators": (
                TosaErrorValidator.evInputZeroPointNotZero,
                TosaErrorValidator.evOutputZeroPointNotZero,
                TosaErrorValidator.evU16InputZeroPointNotValid,
                TosaErrorValidator.evU16OutputZeroPointNotValid,
                TosaErrorValidator.evScaleTrue,
                TosaErrorValidator.evScaleNotTrue,
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
            ),
        },
        # Custom
        # Not implemented.
        # Control flow operators
        # Two variants of cond_if, one that generates one of two constant tensors (no
        # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
        # (two inputs to the basic blocks, one output)
        "cond_if_const": {
            "op": Op.COND_IF,
            "operands": (0, 2),
            "build_fcn": (
                build_cond_if_const,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgCondIfWhileLoop,
                TosaArgGen.agCondIf,
            ),
            "types": [DType.BOOL],
            "error_if_validators": (
                TosaErrorValidator.evOutputListThenGraphMismatch,
                TosaErrorValidator.evOutputListElseGraphMismatch,
                TosaErrorValidator.evCondIfCondNotMatchingBool,
                TosaErrorValidator.evCondIfCondShapeNotSizeOne,
            ),
        },
        "cond_if_binary": {
            "op": Op.COND_IF,
            "operands": (2, 0),
            "build_fcn": (
                build_cond_if_binary,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgCondIfWhileLoop,
                TosaArgGen.agCondIf,
            ),
            "types": TYPE_INT_FP,
            "error_if_validators": (
                TosaErrorValidator.evInputListThenGraphMismatch,
                TosaErrorValidator.evInputListElseGraphMismatch,
                TosaErrorValidator.evOutputListThenGraphMismatch,
                TosaErrorValidator.evOutputListElseGraphMismatch,
                TosaErrorValidator.evCondIfCondNotMatchingBool,
                TosaErrorValidator.evCondIfCondShapeNotSizeOne,
            ),
        },
        # while_loop
        "while_loop": {
            "op": Op.WHILE_LOOP,
            "operands": (0, 1),
            "build_fcn": (
                build_while_loop,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgCondIfWhileLoop,
                TosaArgGen.agWhileLoop,
            ),
            "types": [DType.INT32],
            "error_if_validators": (
                TosaErrorValidator.evInputListOutputListMismatch,
                TosaErrorValidator.evInputListCondGraphMismatch,
                TosaErrorValidator.evInputListBodyGraphInputMismatch,
                TosaErrorValidator.evInputListBodyGraphOutputMismatch,
                TosaErrorValidator.evCondGraphOutputNotMatchingBool,
                TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
            ),
        },
        "fft2d": {
            "op": Op.FFT2D,
            "operands": (2, 0),
            "rank": (3, 3),
            "build_fcn": (
                build_fft2d,
                TosaTensorGen.tgFFT2d,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agFFT2d,
            ),
            "types": [DType.FP32],
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evBatchMismatch,
                TosaErrorValidator.evKernelNotPowerOfTwo,
                TosaErrorValidator.evFFTInputShapeMismatch,
                TosaErrorValidator.evFFTOutputShapeMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.DOT_PRODUCT,),
            },
        },
        "rfft2d": {
            "op": Op.RFFT2D,
            "operands": (1, 0),
            "rank": (3, 3),
            "build_fcn": (
                build_rfft2d,
                TosaTensorGen.tgRFFT2d,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agRFFT2d,
            ),
            "types": [DType.FP32],
            "error_if_validators": (
                TosaErrorValidator.evWrongInputType,
                TosaErrorValidator.evWrongOutputType,
                TosaErrorValidator.evWrongInputList,
                TosaErrorValidator.evWrongOutputList,
                TosaErrorValidator.evWrongRank,
                TosaErrorValidator.evBatchMismatch,
                TosaErrorValidator.evKernelNotPowerOfTwo,
                TosaErrorValidator.evFFTOutputShapeMismatch,
            ),
            "data_gen": {
                "fp": (gtu.DataGenType.DOT_PRODUCT,),
            },
        },
        # Shape
        "add_shape": {
            "op": Op.ADD_SHAPE,
            "operands": (2, 0),
            "rank": (1, 1),
            "build_fcn": (
                build_shape_op,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgAddSub,
                TosaArgGen.agNone,
            ),
            "types": [DType.SHAPE],
            "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
        },
        "sub_shape": {
            "op": Op.SUB_SHAPE,
            "operands": (2, 0),
            "rank": (1, 1),
            "build_fcn": (
                build_shape_op,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgAddSub,
                TosaArgGen.agNone,
            ),
            "types": [DType.SHAPE],
            "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
        },
        "mul_shape": {
            "op": Op.MUL_SHAPE,
            "operands": (2, 0),
            "rank": (1, 1),
            "build_fcn": (
                build_shape_op,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgMul,
                TosaArgGen.agNone,
            ),
            "types": [DType.SHAPE],
            "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
        },
        "div_shape": {
            "op": Op.DIV_SHAPE,
            "operands": (2, 0),
            "rank": (1, 1),
            "build_fcn": (
                build_shape_op,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgIntDiv,
                TosaArgGen.agNone,
            ),
            "types": [DType.SHAPE],
            "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
        },
        "concat_shape": {
            "op": Op.CONCAT_SHAPE,
            "operands": (2, 0),
            "rank": (1, 1),
            "build_fcn": (
                build_concat,
                TosaTensorGen.tgConcat,
                TosaTensorValuesGen.tvgConcat,
                TosaArgGen.agNone,
            ),
            "types": [DType.SHAPE],
            "error_if_validators": (),
        },
        "const_shape": {
            "op": Op.CONST_SHAPE,
            "operands": (0, 1),
            "rank": (1, 1),
            "build_fcn": (
                build_const,
                TosaTensorGen.tgBasic,
                TosaTensorValuesGen.tvgLazyGenDefault,
                TosaArgGen.agNone,
            ),
            "types": [DType.SHAPE],
        },
    }


class OutputShaper:
    # Methods in this class compute the expected output shape and datatype
    # for common classes of operations
    def __init__(self):
        """Create an OutputShaper; no instance attributes are set here."""
        pass

    # These methods return arguments that can be used for
    # creating a new output tensor
    @staticmethod
    def binaryBroadcastOp(ser, rng, a, b, error_name=None):
        """Register the output tensor of a broadcasting binary operator.

        The output shape follows ``a``; when no error is being injected, each
        size-1 dimension of ``a`` is replaced by the matching dimension of
        ``b``.  The output dtype is ``a.dtype``.

        ``error_name`` alters the result as follows:
          - ErrorIf.RankMismatch: the equal-rank assertion is not applied.
          - ErrorIf.DimensionMismatch: one randomly selected output dimension
            is increased by one.
          - ErrorIf.WrongOutputType: the dtype is chosen at random from a
            fixed candidate list with ``a.dtype`` removed.

        Returns the result of ``ser.addOutput(shape, dtype)``.
        """
        rank = len(a.shape)
        if error_name != ErrorIf.RankMismatch:
            assert rank == len(b.shape)
        assert a.dtype == b.dtype

        broadcast = error_name is None
        shape = [
            b.shape[idx] if (dim == 1 and broadcast) else dim
            for idx, dim in enumerate(a.shape)
        ]

        # The index is drawn unconditionally, even when it ends up unused.
        fuzz_idx = rng.integers(0, rank)
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

        outputDType = a.dtype
        if error_name == ErrorIf.WrongOutputType:
            candidates = (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            )
            # Candidates are filtered through a set difference before the draw.
            wrong_dtypes = list(set(candidates) - {a.dtype})
            outputDType = rng.choice(wrong_dtypes)

        return ser.addOutput(shape, outputDType)

    @staticmethod
    def binaryNonBroadcastOp(ser, a, b):
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            assert a.shape[i] == b.shape[i]
            shape.append(a.shape[i])

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def unaryOp(ser, rng, a, error_name=None):
        """Register the output tensor of a unary operator.

        The output reuses ``a.shape`` and ``a.dtype``.  For
        ErrorIf.WrongOutputType the dtype is instead chosen at random from a
        fixed candidate list with ``a.dtype`` removed.

        Returns the result of ``ser.addOutput(a.shape, dtype)``.
        """
        outputDType = a.dtype
        if error_name == ErrorIf.WrongOutputType:
            candidates = (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
            # Candidates are filtered through a set difference before the draw.
            wrong_dtypes = list(set(candidates) - {a.dtype})
            outputDType = rng.choice(wrong_dtypes)

        return ser.addOutput(a.shape, outputDType)

    @staticmethod
    def selectOp(ser, rng, cond, a, b, error_name=None):
        """Register the output tensor of a select operator.

        The output shape follows ``cond``; when no error is being injected,
        each size-1 dimension of ``cond`` becomes the largest of the matching
        dimensions of ``cond``, ``a`` and ``b``.  The output dtype is
        ``a.dtype``.

        ``error_name`` alters the result as follows:
          - ErrorIf.RankMismatch: the equal-rank assertion is not applied.
          - ErrorIf.DimensionMismatch: one randomly selected output dimension
            is increased by one.
          - ErrorIf.WrongOutputType: the dtype is chosen at random from a
            fixed candidate list with ``a.dtype`` removed.

        Returns the result of ``ser.addOutput(shape, dtype)``.
        """
        a_rank = len(a.shape)
        if error_name != ErrorIf.RankMismatch:
            assert a_rank == len(b.shape) and a_rank == len(cond.shape)
        assert a.dtype == b.dtype

        broadcast = error_name is None
        shape = []
        for idx, cond_dim in enumerate(cond.shape):
            if cond_dim == 1 and broadcast:
                cond_dim = max(cond_dim, a.shape[idx], b.shape[idx])
            shape.append(cond_dim)

        # The index is drawn unconditionally, using the rank of ``a``.
        fuzz_idx = rng.integers(0, a_rank)
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

        outputDType = a.dtype
        if error_name == ErrorIf.WrongOutputType:
            candidates = (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
            # Candidates are filtered through a set difference before the draw.
            wrong_dtypes = list(set(candidates) - {a.dtype})
            outputDType = rng.choice(wrong_dtypes)

        return ser.addOutput(shape, outputDType)

    @staticmethod
    def binaryComparisonOp(ser, rng, a, b, error_name=None):
        """Register the output tensor of a broadcasting comparison operator.

        The output shape follows ``a``, with each size-1 dimension of ``a``
        replaced by the matching dimension of ``b`` whenever ``b`` has that
        dimension.  The output dtype is DType.BOOL.

        ``error_name`` alters the result as follows:
          - ErrorIf.RankMismatch: the equal-rank assertion is not applied.
          - ErrorIf.DimensionMismatch: one randomly selected output dimension
            is increased by one.
          - ErrorIf.WrongOutputType: the dtype is chosen at random from a
            fixed list of non-BOOL dtypes.

        Returns the result of ``ser.addOutput(shape, dtype)``.
        """
        a_rank = len(a.shape)
        b_rank = len(b.shape)
        if error_name != ErrorIf.RankMismatch:
            assert a_rank == b_rank
        assert a.dtype == b.dtype

        # Do broadcast
        shape = [
            b.shape[idx] if (dim == 1 and b_rank > idx) else dim
            for idx, dim in enumerate(a.shape)
        ]

        # The index is drawn unconditionally, even when it ends up unused.
        fuzz_idx = rng.integers(0, a_rank)
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

        outputDType = DType.BOOL
        if error_name == ErrorIf.WrongOutputType:
            outputDType = rng.choice(
                [
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                    DType.BF16,
                ]
            )

        return ser.addOutput(shape, outputDType)

    @staticmethod
    def reduceOp(ser, rng, a, axis, error_name=None):
        """Register the output tensor of a reduction operator.

        Starts from a copy of ``a.shape`` and sets the ``axis`` dimension to 1,
        except for the following ``error_name`` values:
          - ErrorIf.AxisSmallerZero / ErrorIf.AxisLargerRank: the shape copy is
            left untouched.
          - ErrorIf.ShapeOfAxisNotOne: the ``axis`` dimension is kept, unless
            it already equals 1, in which case it is replaced by the value of
            ``rng.integers(2, 10)``.

        The output dtype is ``a.dtype``; for ErrorIf.WrongOutputType it is
        chosen at random from a fixed candidate list with ``a.dtype`` removed.

        Returns the result of ``ser.addOutput(shape, dtype)``.
        """
        shape = a.shape.copy()
        if error_name == ErrorIf.ShapeOfAxisNotOne:
            # Make sure the reduced axis is not 1 in the reported output shape.
            if shape[axis] == 1:
                shape[axis] = rng.integers(2, 10)
        elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
            shape[axis] = 1

        outputDType = a.dtype
        if error_name == ErrorIf.WrongOutputType:
            candidates = (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
            # Candidates are filtered through a set difference before the draw.
            wrong_dtypes = list(set(candidates) - {a.dtype})
            outputDType = rng.choice(wrong_dtypes)

        return ser.addOutput(shape, outputDType)

    @staticmethod
    def argmaxOp(ser, rng, a, axis, error_name=None):
        """Register the output tensor of an argmax operator.

        Starts from a copy of ``a.shape`` with the ``axis`` dimension removed
        (the removal is skipped for ErrorIf.AxisSmallerZero and
        ErrorIf.AxisLargerRank).  The output dtype is DType.INT32.

        Further ``error_name`` handling:
          - ErrorIf.ArgmaxOutputRankMismatch: at random, either the first
            dimension is dropped (only if more than one remains) or a trailing
            dimension of size 1 is appended.
          - ErrorIf.ArgmaxOutputShapeMismatch: every dimension is increased by
            its own draw of ``rng.integers(1, 10)``.
          - ErrorIf.WrongOutputType: the dtype is chosen at random from a fixed
            candidate list with DType.INT32 removed.

        Returns the result of ``ser.addOutput(shape, dtype)``.
        """
        shape = a.shape.copy()

        if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
            shape.pop(axis)

        if error_name == ErrorIf.ArgmaxOutputRankMismatch:
            drop_leading = rng.choice([True, False])
            if drop_leading and len(shape) > 1:
                shape.pop(0)
            else:
                shape.append(1)
        elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
            # One draw per dimension, taken in dimension order.
            shape[:] = [dim + rng.integers(1, 10) for dim in shape]

        outputDType = DType.INT32
        if error_name == ErrorIf.WrongOutputType:
            candidates = (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
            # Candidates are filtered through a set difference before the draw.
            wrong_dtypes = list(set(candidates) - {DType.INT32})
            outputDType = rng.choice(wrong_dtypes)

        return ser.addOutput(shape, outputDType)

    @staticmethod
    def conv2dOp(
        ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
    ):
        """Register the output tensor of a 2D convolution.

        Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC.

        The output height and width are derived from the input size, the two
        padding values of that dimension, the dilated kernel extent and the
        stride.  The output dtype is ``accum_dtype``.

        ``error_name`` alters the result as follows:
          - ErrorIf.ConvOutputShapeMismatch: height, width or both are enlarged
            by a random multiple (1 to 3) of the matching stride.
          - ErrorIf.WrongInputType: the output dtype is DType.INT32.
          - ErrorIf.WrongOutputType: the dtype is chosen at random from
            ``gtu.usableDTypes`` with the excluded dtypes depending on
            ``ifm.dtype``.

        Returns the result of ``ser.addOutput(ofm_shape, dtype)``.
        """

        def out_size(in_size, pad_a, pad_b, kernel, dilation, stride):
            # Output extent along a single spatial dimension.
            return (in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1

        h = out_size(
            ifm.shape[1],
            padding[0],
            padding[1],
            filter.shape[1],
            dilations[0],
            strides[0],
        )
        w = out_size(
            ifm.shape[2],
            padding[2],
            padding[3],
            filter.shape[2],
            dilations[1],
            strides[1],
        )

        if error_name == ErrorIf.ConvOutputShapeMismatch:
            choices = [1, 2, 3]
            change = rng.choice(choices)
            # increment in multiples of stride to not hit non-integer error case
            if change in (1, 3):
                h += rng.choice(choices) * strides[0]
            if change in (2, 3):
                w += rng.choice(choices) * strides[1]

        ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = (
            DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
        )

        if error_name == ErrorIf.WrongOutputType:
            if ifm.dtype == DType.FP16:
                excludes = [DType.FP16, DType.FP32]
            elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
                excludes = [DType.FP16]
            else:
                excludes = [out_dtype]
            out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def conv3dOp(
        ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
    ):
        """Register the output tensor of a 3D convolution.

        Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC.

        The output depth, height and width are derived from the input size,
        the two padding values of that dimension, the dilated kernel extent
        and the stride.  The output dtype is ``accum_dtype``.

        ``error_name`` alters the result as follows:
          - ErrorIf.ConvOutputShapeMismatch: depth, height, width or all three
            are enlarged by a random multiple (1 to 4) of the matching stride.
          - ErrorIf.WrongInputType: the output dtype is DType.INT32.
          - ErrorIf.WrongOutputType: the dtype is chosen at random from
            ``gtu.usableDTypes`` with the excluded dtypes depending on
            ``ifm.dtype``.

        Returns the result of ``ser.addOutput(ofm_shape, dtype)``.
        """

        def out_size(axis):
            # axis 0, 1, 2 selects depth, height, width respectively.
            extent = (
                ifm.shape[axis + 1]
                - 1
                + padding[2 * axis]
                + padding[2 * axis + 1]
                - (filter.shape[axis + 1] - 1) * dilations[axis]
            )
            return extent // strides[axis] + 1

        d = out_size(0)
        h = out_size(1)
        w = out_size(2)

        if error_name == ErrorIf.ConvOutputShapeMismatch:
            choices = [1, 2, 3, 4]
            change = rng.choice(choices)
            # increment in multiples of stride to not hit non-integer error case
            if change in (1, 4):
                d += rng.choice(choices) * strides[0]
            if change in (2, 4):
                h += rng.choice(choices) * strides[1]
            if change in (3, 4):
                w += rng.choice(choices) * strides[2]

        ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = (
            DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
        )

        if error_name == ErrorIf.WrongOutputType:
            excludes = (
                [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
            )
            out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def depthwiseConv2dOp(
        ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
    ):
        """Register the output tensor of a depthwise 2D convolution.

        Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M).

        The output height and width are derived from the input size, the two
        padding values of that dimension, the dilated kernel extent and the
        stride; the channel count is ``filter.shape[2] * filter.shape[3]``.
        The output dtype is ``accum_dtype``.

        ``error_name`` alters the result as follows:
          - ErrorIf.ConvOutputShapeMismatch: height, width or both are enlarged
            by a random multiple (1 to 3) of the matching stride.
          - ErrorIf.WrongInputType: the output dtype is DType.INT32.
          - ErrorIf.WrongOutputType: the dtype is chosen at random from
            ``gtu.usableDTypes`` with the excluded dtypes depending on
            ``ifm.dtype``.

        Returns the result of ``ser.addOutput(ofm_shape, dtype)``.
        """

        def out_size(axis):
            # axis 0 selects height, axis 1 selects width.
            extent = (
                ifm.shape[axis + 1]
                - 1
                + padding[2 * axis]
                + padding[2 * axis + 1]
                - (filter.shape[axis] - 1) * dilations[axis]
            )
            return extent // strides[axis] + 1

        h = out_size(0)
        w = out_size(1)

        if error_name == ErrorIf.ConvOutputShapeMismatch:
            choices = [1, 2, 3]
            change = rng.choice(choices)
            # increment in multiples of stride to not hit non-integer error case
            if change in (1, 3):
                h += rng.choice(choices) * strides[0]
            if change in (2, 3):
                w += rng.choice(choices) * strides[1]

        out_channels = filter.shape[2] * filter.shape[3]
        ofm_shape = [ifm.shape[0], h, w, out_channels]

        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = (
            DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
        )

        if error_name == ErrorIf.WrongOutputType:
            excludes = (
                [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
            )
            out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
        """Register the output tensor of a 2D pooling operator with ``ser``.

        The input is NHWC.  ``kernel`` and ``stride`` are read as [0] for the
        height and [1] for the width; ``pad`` is read as [0]/[1] for the
        height and [2]/[3] for the width.  The output keeps the batch and
        channel sizes of ``ifm`` and, unless WrongOutputType is requested,
        its dtype.

        Returns the value produced by ``ser.addOutput``.
        """
        params_usable = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
        if params_usable:
            h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
            w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
        else:
            # Non-positive stride or negative padding: the test is invalid
            # anyway, so fall back to dimensions of 1.
            h = w = 1

        if error_name == ErrorIf.PoolingOutputShapeMismatch:
            choices = [1, 2, 3]
            # 1 -> alter height, 2 -> alter width, 3 -> alter both
            change = rng.choice(choices)
            # Grow by whole multiples of the stride so the non-integer
            # error case is not hit instead
            if change in (1, 3):
                h += rng.choice(choices) * stride[0]
            if change in (2, 3):
                w += rng.choice(choices) * stride[1]
        ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

        outputDType = ifm.dtype
        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            ]
            # Any listed dtype other than the input's own
            wrong_dtypes = list(set(all_dtypes) - {ifm.dtype})
            outputDType = rng.choice(wrong_dtypes)

        return ser.addOutput(ofm_shape, outputDType)

    @staticmethod
    def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
        """Register the output tensor of a fully connected operator with ``ser``.

        Layouts: input is N, IC; filter is OC, IC; output is N, OC.
        ``rng`` and ``error_name`` are accepted but not used here.

        Returns the value produced by ``ser.addOutput``.
        """
        batch = input.shape[0]
        out_channels = filter.shape[0]

        # accum_dtype is validated in arg_gen (also invalidated for ErrorIf),
        # so it is used directly as the output dtype
        return ser.addOutput([batch, out_channels], accum_dtype)

    @staticmethod
    def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
        """Register the output tensor of a matrix multiplication with ``ser``.

        Layouts: ``a`` is N, H, C; ``b`` is N, C, W; the output is N, H, W.

        The output dtype is ``accum_dtype`` (validated in arg_gen) unless
        ``error_name`` requests WrongInputType (INT32 is used) or
        WrongOutputType (a dtype is drawn from a per-input-dtype list of
        incorrect types).

        Returns the value produced by ``ser.addOutput``.
        """
        output_shape = [a.shape[0], a.shape[1], b.shape[2]]

        if error_name != ErrorIf.WrongOutputType:
            if error_name == ErrorIf.WrongInputType:
                # Pick some potentially correct output dtype if input type is incorrect
                out_dtype = DType.INT32
            else:
                out_dtype = accum_dtype
            return ser.addOutput(output_shape, out_dtype)

        # WrongOutputType: select the incorrect output dtypes for this input dtype
        in_dtype = a.dtype
        if in_dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif in_dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif in_dtype in (DType.FP8E4M3, DType.FP8E5M2):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif in_dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )

        return ser.addOutput(output_shape, rng.choice(a=incorrect_types))

    @staticmethod
    def concatOp(ser, rng, axis, inputs, error_name=None):
        """Register the output tensor of a concatenation with ``ser``.

        The output shape starts as a copy of the first input's shape; the
        sizes of the remaining inputs along ``axis`` are added to it when
        that is possible for the requested ``error_name``.  The output dtype
        is the first input's dtype unless WrongOutputType is requested.

        Returns the value produced by ``ser.addOutput``.
        """
        first = inputs[0]
        output_shape = first.shape.copy()

        # The summed shape cannot be calculated for tensors of different
        # ranks or for an invalid axis; the first input's shape is kept then.
        can_sum_axis = (
            error_name != ErrorIf.ConcatInputRankMismatch
            and error_name not in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
        )
        if can_sum_axis:
            for tensor in inputs[1:]:
                output_shape[axis] += tensor.shape[axis]

        if error_name == ErrorIf.ConcatShapeSumMismatch:
            output_shape[axis] += rng.integers(5, 10)

        outputDType = first.dtype
        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = {
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            }
            # Any listed dtype other than the first input's own
            wrong_dtypes = list(all_dtypes - {first.dtype})
            outputDType = rng.choice(wrong_dtypes)

        return ser.addOutput(output_shape, outputDType)

    @staticmethod
    def padOp(ser, rng, a, padding, error_name=None):
        """Register the output tensor of a pad operator with ``ser``.

        Each output dimension is the matching dimension of ``a`` plus the two
        entries ``padding[i][0]`` and ``padding[i][1]``.  ``error_name`` may
        request a shape mismatch, a rank mismatch or a wrong output dtype.

        Returns the value produced by ``ser.addOutput``.
        """
        output_shape = a.shape.copy()
        for idx, dim in enumerate(a.shape):
            output_shape[idx] = padding[idx][0] + padding[idx][1] + dim

        if error_name == ErrorIf.PadOutputShapeMismatch:
            # Enlarge one randomly chosen dimension by 1 or 2
            bad_dim = rng.choice(range(len(output_shape)))
            output_shape[bad_dim] += rng.choice([1, 2])
        elif error_name == ErrorIf.RankMismatch:
            output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

        outputDType = a.dtype
        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
            # Any listed dtype other than the input's own
            wrong_dtypes = list(set(all_dtypes) - {a.dtype})
            outputDType = rng.choice(wrong_dtypes)

        return ser.addOutput(output_shape, outputDType)

    @staticmethod
    def dimOp(ser, rng, a, axis, error_name=None):
        """Register the output of a dim operator with ``ser``.

        The output always has shape ``[1]``.  Its dtype is ``DType.SHAPE``
        unless WrongOutputType is requested, in which case one of the listed
        non-shape dtypes is drawn.  ``a`` and ``axis`` are not used here.

        Returns the value produced by ``ser.addOutput``.
        """
        output_shape = [1]

        if error_name != ErrorIf.WrongOutputType:
            return ser.addOutput(output_shape, DType.SHAPE)

        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Every listed dtype is wrong here, as the correct one is SHAPE
        wrong_dtypes = list(set(all_dtypes))
        return ser.addOutput(output_shape, rng.choice(wrong_dtypes))

    @staticmethod
    def reshapeOp(ser, rng, a, shape, error_name=None):
        """Register the output tensor of a reshape operator with ``ser``.

        The output shape is a copy of ``shape`` and the dtype is that of
        ``a``, unless ``error_name`` requests a size mismatch (every
        dimension is enlarged by a random amount) or a wrong output dtype.

        Returns the value produced by ``ser.addOutput``.
        """
        output_shape = shape.copy()

        if error_name == ErrorIf.TensorSizeInputOutputMismatch:
            for idx, dim in enumerate(shape):
                output_shape[idx] = dim + rng.integers(1, 10)

        outputDType = a.dtype
        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
            # Any listed dtype other than the input's own
            wrong_dtypes = list(set(all_dtypes) - {a.dtype})
            outputDType = rng.choice(wrong_dtypes)

        return ser.addOutput(output_shape, outputDType)

    @staticmethod
    def sliceOp(ser, rng, input, start, size, error_name=None):
        """Register the output tensor of a slice operator with ``ser``.

        The output shape is a copy of ``size`` and the dtype is that of
        ``input``, unless ``error_name`` requests one of the shape or dtype
        variations handled below.  ``start`` is not used here.

        Returns the value produced by ``ser.addOutput``.
        """
        outputDType = input.dtype
        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
            # Any listed dtype other than the input's own
            wrong_dtypes = list(set(all_dtypes) - {input.dtype})
            outputDType = rng.choice(wrong_dtypes)

        output_shape = size.copy()
        if error_name == ErrorIf.SizeOutputShapeMismatch:
            for idx, extent in enumerate(size):
                # Sizes of 2 or less are only ever increased
                deltas = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
                output_shape[idx] = extent + rng.choice(deltas)
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            output_shape = input.shape.copy()
        elif error_name == ErrorIf.RankMismatch:
            output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

        return ser.addOutput(output_shape, outputDType)

    @staticmethod
    def tileOp(ser, rng, a, multiples, error_name=None):
        """Register the output tensor of a tile operator with ``ser``.

        Each output dimension is the matching dimension of ``a`` multiplied
        by the matching entry of ``multiples``.  ``error_name`` may request a
        rank mismatch or a wrong output dtype.

        Returns the value produced by ``ser.addOutput``.
        """
        output_shape = a.shape.copy()
        assert len(multiples) == len(output_shape)

        for idx, dim in enumerate(a.shape):
            output_shape[idx] = dim * multiples[idx]

        if error_name == ErrorIf.RankMismatch:
            output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

        outputDType = a.dtype
        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
            # Any listed dtype other than the input's own
            wrong_dtypes = list(set(all_dtypes) - {a.dtype})
            outputDType = rng.choice(wrong_dtypes)

        return ser.addOutput(output_shape, outputDType)

    @staticmethod
    def transposeOp(ser, rng, a, perms, error_name=None):
        """Register the output tensor of a transpose operator with ``ser``.

        Output dimension ``i`` is ``a.shape[perms[i]]``.  For the
        IndexOutsideBounds and IndexUsedTwice cases the permutation is not
        applied and the input shape is kept.  ``error_name`` may also request
        a size mismatch, a rank mismatch or a wrong output dtype.

        Returns the value produced by ``ser.addOutput``.
        """
        output_shape = a.shape.copy()
        rank = len(output_shape)

        assert len(perms) == rank

        permutation_usable = error_name not in [
            ErrorIf.IndexOutsideBounds,
            ErrorIf.IndexUsedTwice,
        ]
        if permutation_usable:
            for idx in range(rank):
                output_shape[idx] = a.shape[perms[idx]]

        if error_name == ErrorIf.TensorSizeInputOutputMismatch:
            for idx in range(rank):
                output_shape[idx] = output_shape[idx] + rng.integers(1, 10)
        elif error_name == ErrorIf.RankMismatch:
            output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

        outputDType = a.dtype
        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
            # Any listed dtype other than the input's own
            wrong_dtypes = list(set(all_dtypes) - {a.dtype})
            outputDType = rng.choice(wrong_dtypes)

        return ser.addOutput(output_shape, outputDType)

    @staticmethod
    def gatherOp(ser, rng, values, indices, error_name=None):
        """Register the output tensor of a gather operator with ``ser``.

        The output shape is built from dimension 0 of ``values``, dimension 1
        of ``indices`` and dimension 2 of ``values``.  The dtype is that of
        ``values`` unless WrongOutputType is requested.

        Returns the value produced by ``ser.addOutput``.
        """
        # The rank of values is only checked when WrongRank is not under test
        assert error_name == ErrorIf.WrongRank or len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

        output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

        outputDType = values.dtype
        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
            # Any listed dtype other than that of values
            wrong_dtypes = list(set(all_dtypes) - {values.dtype})
            outputDType = rng.choice(wrong_dtypes)

        return ser.addOutput(output_shape, outputDType)

    @staticmethod
    def scatterOp(ser, rng, values_in, indices, input, error_name=None):
        """Register the output tensor of a scatter operator with ``ser``.

        The output uses the shape of ``values_in`` (passed on without
        copying) and its dtype, unless WrongOutputType is requested.

        Returns the value produced by ``ser.addOutput``.
        """
        # The rank of values_in is only checked when WrongRank is not under test
        assert error_name == ErrorIf.WrongRank or len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

        outputDType = values_in.dtype
        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            ]
            # Any listed dtype other than that of values_in
            wrong_dtypes = list(set(all_dtypes) - {values_in.dtype})
            outputDType = rng.choice(wrong_dtypes)

        return ser.addOutput(values_in.shape, outputDType)

    @staticmethod
    def tableOp(ser, rng, input, error_name=None):
        """Register the output tensor of a table operator with ``ser``.

        The output has the same shape as ``input``.  An INT16 input gives an
        INT32 output and any other input dtype gives INT8, unless
        WrongOutputType is requested.

        Returns the value produced by ``ser.addOutput``.
        """
        # The input dtype is only checked when WrongInputType is not under test
        assert error_name == ErrorIf.WrongInputType or (
            input.dtype == DType.INT16 or input.dtype == DType.INT8
        )

        if input.dtype == DType.INT16:
            output_dtype = DType.INT32
        else:
            output_dtype = DType.INT8

        if error_name == ErrorIf.WrongOutputType:
            candidates = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
            # Everything listed except the dtype that would be correct
            wrong_dtypes = [dtype for dtype in candidates if dtype != output_dtype]
            output_dtype = rng.choice(wrong_dtypes)

        return ser.addOutput(input.shape, output_dtype)

    @staticmethod
    def resizeOp(
        serializer,
        rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name=None,
    ):
        """Register the output tensor of a resize operator with ``serializer``.

        ``scale`` is read as (y numerator, y denominator, x numerator,
        x denominator); ``offset`` and ``border`` are read as [0] for y and
        [1] for x.  The output height and width are calculated from
        dimensions 1 and 2 of ``input``.  ``mode`` and ``input_dtype`` are
        not used here.

        Returns the value produced by ``serializer.addOutput``.
        """

        def shifted(extent, step):
            # Move extent by step: downwards if going up would reach
            # gtu.MAX_RESIZE_DIMENSION, otherwise upwards
            if extent + step >= gtu.MAX_RESIZE_DIMENSION:
                extent -= step
                assert extent > 0  # Should have been caught in agResize
            else:
                extent += step
            return extent

        scale_y_n, scale_y_d = scale[0], scale[1]
        scale_x_n, scale_x_d = scale[2], scale[3]
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            # Use values of at least 1 for the shape calculation
            scale_y_n, scale_y_d = max(scale_y_n, 1), max(scale_y_d, 1)
            scale_x_n, scale_x_d = max(scale_x_n, 1), max(scale_x_d, 1)

        # Calculate OH, OW
        oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
        ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

        if error_name is not None:
            # Keep the output tensor valid, as scale, offset or border may
            # have been changed for ERROR_IFs
            oh, ow = max(oh, 1), max(ow, 1)
            if error_name != ErrorIf.MaxDimExceeded:
                upper = gtu.MAX_RESIZE_DIMENSION - 1
                oh, ow = min(oh, upper), min(ow, upper)

        if error_name == ErrorIf.ResizeOutputShapeMismatch:
            # 1 -> alter height, 2 -> alter width, 3 -> alter both
            change = rng.choice([1, 2, 3])
            # Step in multiples of scale_y/x_d so we don't hit non-integer error case
            if change in (1, 3):
                oh = shifted(oh, scale_y_d)
            if change in (2, 3):
                ow = shifted(ow, scale_x_d)

        batch = input.shape[0]
        if error_name == ErrorIf.WrongRank:
            # Dimension 3 of the input is not read in this case
            channels = input.shape[0]
        elif error_name == ErrorIf.BatchMismatch:
            batch = input.shape[0] + rng.integers(1, 10)
            channels = input.shape[3]
        elif error_name == ErrorIf.ChannelMismatch:
            channels = input.shape[3] + rng.integers(1, 10)
        else:
            channels = input.shape[3]

        return serializer.addOutput([batch, oh, ow, channels], output_dtype)

    @staticmethod
    def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
        """Register an output with the shape of ``val`` and dtype ``out_dtype``.

        ``rng`` and ``error_name`` are accepted but not used here.
        Returns the value produced by ``ser.addOutput``.
        """
        output_shape = val.shape
        return ser.addOutput(output_shape, out_dtype)

    @staticmethod
    def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
        """Register the output tensor of a transpose 2D convolution with ``ser``.

        ``output_shape`` is used as given, except that for
        ConvOutputShapeMismatch its entries [1] and/or [2] are enlarged;
        this modifies the passed-in ``output_shape`` object itself.  The
        dtype is ``accum_dtype`` unless a wrong input or output type is
        requested.

        Returns the value produced by ``ser.addOutput``.
        """
        if error_name == ErrorIf.ConvOutputShapeMismatch:
            choices = [1, 2, 3]
            # 1 -> alter entry [1], 2 -> alter entry [2], 3 -> alter both
            change = rng.choice(choices)
            if change in (1, 3):
                output_shape[1] += rng.choice(choices)
            if change in (2, 3):
                output_shape[2] += rng.choice(choices)

        # With an incorrect input type, pick some potentially correct output dtype
        out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

        if error_name == ErrorIf.WrongOutputType:
            excludes = (
                [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
            )
            out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

        return ser.addOutput(output_shape, out_dtype)

    @staticmethod
    def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
        """Register the two output tensors of a 2D FFT with ``serializer``.

        Both outputs share one shape and dtype, taken from ``ifm1`` unless
        ``error_name`` requests a wrong output dtype, a batch mismatch or an
        output shape mismatch.

        Returns a list holding the two values produced by
        ``serializer.addOutput``.
        """
        assert ifm1.dtype == ifm2.dtype
        # Shape and rank are only checked when not themselves under test
        assert error_name == ErrorIf.FFTInputShapeMismatch or ifm1.shape == ifm2.shape
        assert error_name == ErrorIf.WrongRank or len(ifm1.shape) == 3

        output_shape = ifm1.shape.copy()
        output_dtype = ifm1.dtype

        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = list(gtu.usableDTypes(excludes=[DType.FP32]))
            output_dtype = rng.choice(wrong_dtypes)
        elif error_name == ErrorIf.BatchMismatch:
            output_shape[0] += rng.integers(1, 10)
        elif error_name == ErrorIf.FFTOutputShapeMismatch:
            # Enlarge one of dimensions 1 and 2
            modify_dim = rng.choice([1, 2])
            output_shape[modify_dim] += rng.integers(1, 10)

        return [
            serializer.addOutput(output_shape, output_dtype),
            serializer.addOutput(output_shape, output_dtype),
        ]

    @staticmethod
    def rfft2dOp(serializer, rng, value, error_name=None):
        """Register the two output tensors of a real 2D FFT with ``serializer``.

        The output shape keeps every dimension of ``value`` except the last,
        which becomes ``last // 2 + 1``.  Both outputs share that shape and
        the dtype of ``value``, unless ``error_name`` requests a wrong output
        dtype, a batch mismatch or an output shape mismatch.

        Returns a list holding the two values produced by
        ``serializer.addOutput``.
        """
        input_shape = value.shape
        # The rank is only checked when WrongRank is not under test
        assert error_name == ErrorIf.WrongRank or len(input_shape) == 3

        leading_dims = input_shape[:-1]
        last_dim = input_shape[-1] // 2 + 1
        output_shape = [*leading_dims, last_dim]

        output_dtype = value.dtype
        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = list(gtu.usableDTypes(excludes=[DType.FP32]))
            output_dtype = rng.choice(wrong_dtypes)
        elif error_name == ErrorIf.BatchMismatch:
            output_shape[0] += rng.integers(1, 10)
        elif error_name == ErrorIf.FFTOutputShapeMismatch:
            # Enlarge one of dimensions 1 and 2
            modify_dim = rng.choice([1, 2])
            output_shape[modify_dim] += rng.integers(1, 10)

        return [
            serializer.addOutput(output_shape, output_dtype),
            serializer.addOutput(output_shape, output_dtype),
        ]

    @staticmethod
    def addShapeOp(ser, rng, a, b, error_name=None):
        """Register the output of a shape operator taking ``a`` and ``b``.

        The output shape has the same entries as ``a.shape``, with one
        randomly chosen entry increased by 1 for DimensionMismatch.  The
        dtype is ``DType.SHAPE`` unless WrongOutputType is requested.

        Returns the value produced by ``ser.addOutput``.
        """
        # Ranks are only compared when RankMismatch is not under test
        assert error_name == ErrorIf.RankMismatch or len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = list(a.shape)

        # The index is drawn on every call, whether or not it is then used
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

        outputDType = DType.SHAPE
        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = gtu.get_wrong_output_type(a.dtype)
            outputDType = rng.choice(wrong_dtypes)
        return ser.addOutput(shape, outputDType)
