blob: cbac081e4c807d4dae5e56888e1805d503d6833a [file] [log] [blame]
# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import os
from copy import deepcopy
from datetime import datetime
from pathlib import Path
import generator.tosa_utils as gtu
import numpy as np
import serializer.tosa_serializer as ts
from generator.datagenerator import GenerateLibrary
from generator.tosa_arg_gen import TosaArgGen
from generator.tosa_arg_gen import TosaQuantGen
from generator.tosa_arg_gen import TosaTensorGen
from generator.tosa_arg_gen import TosaTensorValuesGen
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_error_if import TosaErrorValidator
from generator.tosa_error_if import TosaInvalidValidator
from generator.tosa_random_gen import TosaHashRandomGenerator
from generator.tosa_random_gen import TosaRandomGenerator
from schemavalidation.schemavalidation import TestDescSchemaValidator
from tosa.DType import DType
from tosa.Op import Op
# C++ comment header prepended to every auto-generated meta data file.
# The copyright year is evaluated once at module import time.
TOSA_AUTOGENERATED_HEADER = f"""// Copyright (c) {datetime.today().year}, ARM Limited
// SPDX-License-Identifier: Apache-2.0
// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests
"""

# Module logger; level is expected to be configured by the caller/CLI.
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")
class TosaTestGen:
    """Generator of TOSA conformance tests (serialized graphs plus meta data)."""

    # This currently matches the 8K level defined in the specification.
    TOSA_8K_LEVEL_MAX_SCALE = 64
    TOSA_8K_LEVEL_MAX_KERNEL = 8192
    TOSA_8K_LEVEL_MAX_STRIDE = 8192

    # Main compliance dot product statistical test range
    TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
    TOSA_MI_DOT_PRODUCT_MIN = 1000
def __init__(self, args):
    """Set up the test generator from parsed command line arguments.

    Args:
        args: argparse-style namespace; fields read here include
            output_dir, random_seed, generate_lib_path,
            tensor_shape_range and tensor_fp_value_range.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # Serializer is created per-test by createSerializer()
    self.ser = None
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # Set properly by resetGlobalRNG() at the end of construction
    self.global_rng = None
    # Force makeShape to do a specific starting shape
    self.targetted_shape = None
    # JSON schema validation
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    # Work out floating point range
    def convertFPRange(rangeFP, maxFP):
        # Converts program arguments of max/-max to FP max
        vals = []
        for v in rangeFP:
            if v == "max":
                v = maxFP
            elif v == "-max":
                v = -maxFP
            elif v < 0:
                # Trim to minimum data type value
                v = max(v, -maxFP)
            elif v > 0:
                # Trim to maximum data type value
                v = min(v, maxFP)
            vals.append(v)
        return tuple(sorted(vals))

    # Per-dtype (low, high) ranges used when generating random values
    self.random_dtype_range = {
        DType.SHAPE: tuple(self.args.tensor_shape_range[0:2])
    }
    for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
        self.random_dtype_range[dtype] = convertFPRange(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
        )
    self.resetGlobalRNG()
def resetGlobalRNG(self):
    """Recreate the global random generator from the stored seed and dtype ranges."""
    self.global_rng = TosaRandomGenerator(self.random_seed, self.random_dtype_range)
def createSerializer(self, opName, testPath):
    """Create a fresh TOSA serializer writing under basePath/opName/testPath."""
    self.testPath = os.path.join(opName, testPath)

    target_dir = os.path.join(self.basePath, self.testPath)
    os.makedirs(target_dir, exist_ok=True)

    # Choose how constant data is stored alongside the flatbuffer
    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        mode = ts.ConstMode.EMBED

    self.ser = ts.TosaSerializer(target_dir, mode)
def getSerializer(self):
    """Return the current serializer (set up by createSerializer)."""
    return self.ser
def serialize(self, testName, metaData=None, tags=None):
    """Write out the test's flatbuffer, optional C++ meta data, and desc.json.

    Args:
        testName: base name used for all the output files
        metaData: optional dict stored under "meta" in desc.json; its
            "data_gen" entry (in lazy mode) and "compliance" entry are also
            emitted as embedded-JSON C++ source files
        tags: optional list stored under "tag" in desc.json
    """
    path = Path(self.basePath) / self.testPath

    # Write out TOSA flatbuffer binary
    path_fb = path / f"{testName}.tosa"
    with path_fb.open("wb") as fd:
        fd.write(self.ser.serialize())

    # Get JSON descriptor from serializer
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))

    if metaData:
        # Add extra meta data to desc.json
        desc["meta"] = metaData
    if tags:
        desc["tag"] = tags

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        # The two meta files share a writer to avoid the previous
        # copy-pasted duplication of the embedding logic
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Output datagen meta data as CPP data
            self._write_cpp_meta(
                path,
                f"{testName}_meta_data_gen.cpp",
                "// Test meta data for data generation setup\n\n",
                f"json_tdg_config_{path.stem}",
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Output compliance meta data as CPP data
            self._write_cpp_meta(
                path,
                f"{testName}_meta_compliance.cpp",
                "// Test meta data for compliance validation\n\n",
                f"json_tvf_config_{path.stem}",
                metaData["compliance"],
            )

    # Write desc.json
    path_desc = path / "desc.json"
    with path_desc.open("w") as fd:
        json.dump(desc, fd, indent=1)

def _write_cpp_meta(self, path, filename, comment, varname, data):
    # Emit `data` as JSON embedded in an auto-generated C++ raw string literal
    path_md = path / filename
    with path_md.open("w") as fd:
        fd.write(TOSA_AUTOGENERATED_HEADER)
        fd.write(comment)
        fd.write(f'const char* {varname} = R"(')
        json.dump(data, fd)
        fd.write(')";\n\n')
def buildPlaceholderTensors(self, rng, shape_list, dtype_list):
    """Add a placeholder tensor to the serializer for each shape/dtype pair."""
    assert len(shape_list) == len(dtype_list)
    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        # In lazy data generation mode the content is produced later
        data = None if self.args.lazy_data_gen else rng.randTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, data))
    return placeholders
def buildConstTensors(self, rng, shape_list, dtype_list):
    """Add a constant tensor to the serializer for each shape/dtype pair."""
    assert len(shape_list) == len(dtype_list)
    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        # In lazy data generation mode the content is produced later
        data = None if self.args.lazy_data_gen else rng.randTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, data))
    return consts
def makeShape(self, rng, rank):
if self.targetted_shape:
return np.int32(self.targetted_shape)
else:
return np.int32(
rng.integers(
low=self.args.tensor_shape_range[0],
high=self.args.tensor_shape_range[1],
size=rank,
)
)
def setTargetShape(self, shape):
    # Force makeShape to return this specific shape (pass None to disable)
    self.targetted_shape = shape
def shapeStr(self, shape):
    """Return a printable form of a shape, e.g. "2x3x4" ("0" for rank 0)."""
    assert shape is not None
    if not len(shape):
        # Rank 0
        return "0"
    # Rank > 0
    return "x".join(str(dim) for dim in shape)
def typeStr(self, dtype):
    """Return the printable name of a dtype (or pair of dtypes).

    Args:
        dtype: a DType value, or a list/tuple of at least two DType values
            where the third entry (if present) is the accumulator type

    Returns:
        str such as "f32", or "AxB" for a dtype pair

    Raises:
        Exception: when a dtype has no DTYPE_ATTRIBUTES entry
    """
    if isinstance(dtype, (list, tuple)):
        assert len(dtype) >= 2
        strs = [self.typeStr(t) for t in dtype]
        # Limit types to the first 2 as the 3rd is the accumulator
        return "x".join(strs[:2])
    if dtype in gtu.DTYPE_ATTRIBUTES:
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
    raise Exception(
        "Unknown dtype, cannot convert to string: {}".format(dtype)
    )
def constrictBatchSize(self, shape):
# Limit the batch size unless an explicit target shape set
if self.args.max_batch_size and not self.args.target_shapes:
shape[0] = min(shape[0], self.args.max_batch_size)
return shape
def makeDimension(self, rng):
    """Return one random dimension size within the configured shape range."""
    return rng.randInt(
        low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
    )
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Create the compliance meta data dict for an expected output tensor.

    Args:
        op: operator info dict ("op", "compliance", "data_gen" entries used)
        inputType: DType of the (first) input tensor
        argsDict: test arguments; "dg_type" selects the compliance mode
        outputTensor: expected output tensor (name/dtype read here)
        errorName: ERROR_IF name, or None for a valid test

    Returns:
        Compliance dict for this tensor, or None when no compliance data
        applies (error tests or unsupported data types).
    """
    # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
    UNSUPPORTED_NON_FP32_INPUT_OPS = (
        Op.MATMUL,
        Op.CONV2D,
        Op.FULLY_CONNECTED,
        Op.DEPTHWISE_CONV2D,
        Op.TRANSPOSE_CONV2D,
        Op.CONV3D,
    )
    if (
        errorName
        or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
        or (
            not gtu.dtypeIsSupportedByCompliance(inputType)
            and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
        )
    ):
        # No compliance for error tests or unsupported types currently
        return None

    # Create compliance meta data for expected output tensor
    compliance_tens = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }
    op_compliance = op.get("compliance", {})

    mode = None
    # Check what data generation we have done
    if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        # "ksb" takes precedence over "ks" when both are present
        compliance_tens["dot_product_info"] = {
            "s": argsDict["s"],
            "ks": (
                int(argsDict["ksb"]) if "ksb" in argsDict else int(argsDict["ks"])
            ),
        }
    elif "ulp" in op_compliance:
        mode = gtu.ComplianceMode.ULP
        compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif "relative" in op_compliance:
        mode = gtu.ComplianceMode.RELATIVE
        compliance_tens["relative_info"] = {
            "max": argsDict["max_abs_value"],
            "scale": op["compliance"]["relative"],
        }
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        # "n" is the number of products, set by build_reduce
        compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        if "abs_error_lower_bound" in op_compliance:
            compliance_tens["abs_error_info"] = {
                "lower_bound": op["compliance"]["abs_error_lower_bound"]
            }
    elif op["op"] in (Op.SIN, Op.COS):
        mode = gtu.ComplianceMode.ABS_ERROR
        normal_divisor = op_compliance.get("abs_error_normal_divisor", 1)
        bound_addition = op_compliance.get("abs_error_bound_addition", 0)
        compliance_tens["abs_error_info"] = {
            "normal_divisor": normal_divisor,
            "bound_as_magnitude": True,
            "bound_addition": bound_addition,
        }
    elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL:
        # NOTE(review): this tests a ComplianceMode value for membership in
        # op["data_gen"][inputType] - confirm that list really holds
        # ComplianceMode values rather than DataGenType values
        if gtu.ComplianceMode.DOT_PRODUCT in op["data_gen"][inputType]:
            # Use special mode that only checks for matching inf/nan/zeroes
            # as normal values need statistical analysis
            mode = gtu.ComplianceMode.FP_SPECIAL
        else:
            mode = gtu.ComplianceMode.EXACT
    else:
        mode = gtu.ComplianceMode.EXACT

    compliance_tens["mode"] = gtu.ComplianceMode(mode).name

    return compliance_tens
# Build Op functions
# Create the output tensor (calling OutputShaper as needed)
# Do final tweaks to attributes (if necessary for errorIf)
# Add Op into graph
# Return resulting tensor information or BuildInfo

class BuildInfo:
    """Build information holding result tensor(s) and compliance dict(s)."""

    def __init__(self, resultTensor, complianceDict):
        if isinstance(resultTensor, list):
            # Multi-output build: compliance must be a parallel list (or None)
            assert complianceDict is None or isinstance(complianceDict, list)
            self.resultTensorList = resultTensor
            self.complianceDictList = complianceDict
        else:
            # Single-output build: normalise to one-element lists
            self.resultTensorList = [resultTensor]
            self.complianceDictList = (
                None if complianceDict is None else [complianceDict]
            )

    def getComplianceInfo(self):
        """Return the versioned compliance info dict, or None when absent."""
        if self.complianceDictList is None:
            return None
        # Keep only the tensors that actually carry compliance data
        tens_dict = {
            tens.name: comp
            for tens, comp in zip(self.resultTensorList, self.complianceDictList)
            if comp is not None
        }
        if not tens_dict:
            return None
        return {
            "version": "0.1",
            "tensors": tens_dict,
        }
def build_unary(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a unary elementwise operator from a single input tensor.

    Returns a BuildInfo on success, or None when ERROR_IF validation
    decided the test should not be serialized.
    """
    assert len(inputs) == 1
    a = inputs[0]
    result_tensor = OutputShaper.unaryOp(self.ser, rng, a, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType:
        if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, a.dtype),
                TosaQuantGen.getZeroPoint(
                    rng, self.args.zeropoint, result_tensor.dtype
                ),
            ]

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        # NEGATE carries the input/output zero points as an attribute
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_binary_broadcast(
    self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a binary elementwise operator with broadcast-shaped output.

    Returns a BuildInfo on success, or None when ERROR_IF validation
    decided the test should not be serialized.
    """
    assert len(inputs) == 2
    a, b = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_arithmetic_right_shift(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an ARITHMETIC_RIGHT_SHIFT operator.

    args_dict["round"] selects the rounding behaviour of the shift.
    Returns a BuildInfo on success, or None when ERROR_IF validation
    decided the test should not be serialized.
    """
    assert len(inputs) == 2
    a, b = inputs
    # Renamed from `round` to avoid shadowing the builtin round()
    round_mode = args_dict["round"]
    result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round_mode)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_mul(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a MUL operator.

    Returns a BuildInfo on success, or None when ERROR_IF validation
    decided the test should not be serialized.
    """
    # Note that mul is binary operator but it has a shift value tensor
    assert len(inputs) == 3
    a, b, s = inputs

    result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)

    # Special for multiply: Force the result to INT32 for INT types
    if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        result_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        outputDType = rng.choice(all_dtypes)
        result_tensor.setDtype(outputDType)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name, s.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_table(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TABLE operator using the lookup table in args_dict["table"].

    Returns a BuildInfo on success, or None when ERROR_IF validation
    decided the test should not be serialized.
    """
    assert len(inputs) == 1
    a = inputs[0]
    table = args_dict["table"]
    result_tensor = OutputShaper.tableOp(self.ser, rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_select(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a SELECT operator from condition and two value tensors.

    Returns a BuildInfo on success, or None when ERROR_IF validation
    decided the test should not be serialized.
    """
    assert len(inputs) == 3
    cond, a, b = inputs

    result_tensor = OutputShaper.selectOp(self.ser, rng, cond, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [cond.name, a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(
        op["op"],
        input_list,
        output_list,
    )
    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_comparison(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a binary comparison operator (boolean-typed output).

    Returns a BuildInfo on success, or None when ERROR_IF validation
    decided the test should not be serialized.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.binaryComparisonOp(self.ser, rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(
        op["op"],
        input_list,
        output_list,
    )
    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_argmax(
    self, rng, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
):
    """Build an ARGMAX operator reducing along args_dict["axis"].

    Returns a BuildInfo on success, or None when ERROR_IF validation
    decided the test should not be serialized.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_pool2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator.

    args_dict supplies "stride", "pad", "kernel" and (avg_pool only)
    "acc_type". Returns a BuildInfo on success, or None when ERROR_IF
    validation decided the test should not be serialized.
    """
    assert len(inputs) == 1
    # Renamed from `input` to avoid shadowing the builtin and to match
    # the conv builders' naming
    ifm = inputs[0]

    # max_pool has no accum_dtype
    accum_dtype = (
        args_dict["acc_type"] if "acc_type" in args_dict else DType.UNKNOWN
    )
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in [DType.INT8, DType.UINT8]:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
                TosaQuantGen.getZeroPoint(
                    rng, self.args.zeropoint, result_tensor.dtype
                ),
            ]

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if qinfo is None:
        # No quantization info supplied - default both zero points to 0
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator from input, filter and bias tensors.

    args_dict supplies "acc_type", "stride", "pad" and "dilation".
    Returns a BuildInfo on success, or None when ERROR_IF validation
    decided the test should not be serialized.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]
    # 2D convolution uses 4 padding values: top, bottom, left, right
    assert len(padding) == 4

    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    # NOTE(review): unlike build_pool2d there is no `qinfo is None` fallback
    # here - presumably callers always supply qinfo for conv ops; confirm
    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_conv3d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator from input, filter and bias tensors.

    args_dict supplies "acc_type", "stride", "pad" and "dilation".
    Returns a BuildInfo on success, or None when ERROR_IF validation
    decided the test should not be serialized.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]
    # 3D convolution uses 6 padding values (front/back added)
    assert len(padding) == 6

    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_transpose_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator from input, filter and bias tensors.

    args_dict supplies "acc_type", "stride" and (output) "pad".
    Returns a BuildInfo on success, or None when ERROR_IF validation
    decided the test should not be serialized.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    # Transpose conv uses 4 output padding values
    assert len(out_pad) == 4

    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, rng, ifm, filter, accum_dtype, strides, out_pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_depthwise_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator from input, filter and bias tensors.

    args_dict supplies "acc_type", "stride", "pad" and "dilation".
    Returns a BuildInfo on success, or None when ERROR_IF validation
    decided the test should not be serialized.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_fully_connected(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator from input, filter and bias tensors.

    args_dict supplies "acc_type". Returns a BuildInfo on success, or
    None when ERROR_IF validation decided the test should not be
    serialized.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # NOTE(review): qinfo is indexed without a None fallback - presumably
    # callers always supply it for this op; confirm
    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_matmul(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a MATMUL operator from two input tensors.

    args_dict supplies "acc_type". Returns a BuildInfo on success, or
    None when ERROR_IF validation decided the test should not be
    serialized.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.matmulOp(
        self.ser, rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # NOTE(review): qinfo is indexed without a None fallback - presumably
    # callers always supply it for this op; confirm
    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_reduce(
    self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a reduction operator test case (REDUCE_SUM/MAX/MIN/etc).

    For REDUCE_PRODUCT positive tests this also records the number of
    multiplied elements in ``args_dict["n"]`` for compliance checking.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_clamp(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CLAMP operator test case.

    Picks two random values of the input dtype for min/max.  For the
    MaxSmallerMin negative test the assignment is deliberately swapped so
    that max < min.  The min/max values are serialized as byte vectors,
    zero-padded to an 8-byte boundary for the flatbuffer attribute.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    assert len(inputs) == 1
    a = inputs[0]

    result_tensor = OutputShaper.unaryOp(self.ser, rng, a, error_name)

    v = [rng.randNumberDType(a.dtype), rng.randNumberDType(a.dtype)]

    if error_name == ErrorIf.MaxSmallerMin:
        # Make sure the numbers are different to invoke this error
        while v[0] == v[1]:
            v = [rng.randNumberDType(a.dtype), rng.randNumberDType(a.dtype)]
        # Deliberately inverted: max < min triggers the ERROR_IF condition
        max_val = min(v)
        min_val = max(v)
    else:
        max_val = max(v)
        min_val = min(v)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()

    min_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec(a.dtype, [min_val])
    max_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec(a.dtype, [max_val])

    # align to 8 bytes
    while (len(min_val_as_bytes) % 8) != 0:
        min_val_as_bytes.append(0)
    while (len(max_val_as_bytes) % 8) != 0:
        max_val_as_bytes.append(0)

    attr.ClampAttribute(self.ser.builder, min_val_as_bytes, max_val_as_bytes)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_activation(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a unary activation operator test case.

    Shapes the output, applies any ERROR_IF input/output list corruption,
    validates, then serializes the (attribute-less) operator.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    assert len(inputs) == 1
    value = inputs[0]

    result_tensor = OutputShaper.unaryOp(self.ser, rng, value, error_name)

    # Operand counts and tensor name lists; the lists may be deliberately
    # corrupted for ERROR_IF testing.
    p_count, c_count = op["operands"]
    tensors_in, tensors_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [value.name], [result_tensor.name]
    )

    validation_args = dict(
        op=op,
        input_shape=value.shape,
        output_shape=result_tensor.shape,
        input_dtype=value.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=tensors_in,
        output_list=tensors_out,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    self.ser.addOperator(op["op"], tensors_in, tensors_out)

    compliance = self.tensorComplianceMetaData(
        op, value.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_concat(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONCAT (or CONCAT_SHAPE) operator test case.

    CONCAT_SHAPE always concatenates along axis 0 and is serialized
    without an attribute; CONCAT takes the axis from args_dict and uses
    an AxisAttribute.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    if op["op"] == Op.CONCAT_SHAPE:
        axis = 0
    else:
        axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        assert type(axis) == int

    result_tensor = OutputShaper.concatOp(
        self.ser, rng, axis, inputs, error_name=error_name
    )

    input_tensor_names = []
    for tensor in inputs:
        input_tensor_names.append(tensor.name)

    # Invalidate Input/Output list for error if checks.
    input_list = input_tensor_names
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        # CONCAT_SHAPE carries no attribute
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_pad(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator test case.

    Serializes the pad constant (float or int, chosen by the input dtype)
    as a byte vector zero-padded to an 8-byte boundary.  NOTE: the
    attribute is constructed before ERROR_IF validation runs.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    assert len(inputs) == 2
    a = inputs[0]
    pad_input = inputs[1]
    padding = args_dict["pad"]
    pad_const_int = args_dict["pad_const_int"]
    pad_const_float = args_dict["pad_const_fp"]

    result_tensor = OutputShaper.padOp(self.ser, rng, a, padding, error_name)

    # get pad_const_val_as_bytes from either pad_const_float or pad_const_int
    if gtu.dtypeIsFloat(a.dtype):
        pad_const_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec(
            a.dtype, [pad_const_float]
        )
    else:
        pad_const_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec(
            a.dtype, [pad_const_int]
        )

    # align to 8 bytes
    while (len(pad_const_val_as_bytes) % 8) != 0:
        pad_const_val_as_bytes.append(0)

    attr = ts.TosaSerializerAttribute()
    attr.PadAttribute(self.ser.builder, pad_const_val_as_bytes)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, pad_input.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        input1=a,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_dim(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator test case.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, None) - no compliance
        metadata is produced for DIM - or None when the ERROR_IF
        validators reject this combination.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.dimOp(self.ser, rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return TosaTestGen.BuildInfo(result_tensor, None)
def build_reshape(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESHAPE operator test case.

    The new shape arrives both as a shape tensor input and as the
    "new_shape" entry of args_dict (used to compute the output tensor).

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    assert len(inputs) == 2
    data_tensor, shape_tensor = inputs[0], inputs[1]
    target_shape = args_dict["new_shape"]

    result_tensor = OutputShaper.reshapeOp(
        self.ser, rng, data_tensor, target_shape, error_name
    )

    # Name lists, possibly corrupted for ERROR_IF testing.
    p_count, c_count = op["operands"]
    tensors_in, tensors_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [data_tensor.name, shape_tensor.name],
        [result_tensor.name],
    )

    validation_args = dict(
        op=op,
        input_shape=data_tensor.shape,
        output_shape=result_tensor.shape,
        input_dtype=data_tensor.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=tensors_in,
        output_list=tensors_out,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    self.ser.addOperator(op["op"], tensors_in, tensors_out)

    compliance = self.tensorComplianceMetaData(
        op, data_tensor.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_reverse(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a REVERSE operator test case.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, None) - no compliance
        metadata is produced for REVERSE - or None when the ERROR_IF
        validators reject this combination.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.unaryOp(self.ser, rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return TosaTestGen.BuildInfo(result_tensor, None)
def build_transpose(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE operator test case.

    The permutation comes from args_dict["perms"].  NOTE: the attribute
    is constructed before ERROR_IF validation runs.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    assert len(inputs) == 1
    a = inputs[0]
    perms = args_dict["perms"]

    result_tensor = OutputShaper.transposeOp(self.ser, rng, a, perms, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(perms)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        input1=a,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_slice(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a SLICE operator test case.

    start/size arrive both as shape-tensor inputs and as constants in
    args_dict (the constants drive output shaping and validation).

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    assert len(inputs) == 3
    a, start_var, size_var = inputs
    start_const = args_dict["start"]
    size_const = args_dict["size"]

    result_tensor = OutputShaper.sliceOp(
        self.ser, rng, a, start_const, size_const, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, start_var.name, size_var.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        start=start_const,
        size=size_const,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        input1=a,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_tile(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TILE operator test case.

    Multiples arrive both as a shape-tensor input and as
    args_dict["multiples"] (used to compute the output shape).

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    assert len(inputs) == 2
    a = inputs[0]
    multiples = inputs[1]
    multiples_attr = args_dict["multiples"]
    result_tensor = OutputShaper.tileOp(
        self.ser, rng, a, multiples_attr, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, multiples.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        input1=a,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_gather(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a GATHER operator test case.

    Takes a (values, indices) tensor pair, shapes the output, validates
    any ERROR_IF condition and serializes the operator.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    assert len(inputs) == 2
    values, indices = inputs

    result_tensor = OutputShaper.gatherOp(
        self.ser, rng, values, indices, error_name
    )

    # Name lists, possibly corrupted for ERROR_IF testing.
    p_count, c_count = op["operands"]
    tensors_in, tensors_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [values.name, indices.name], [result_tensor.name]
    )

    validation_args = dict(
        op=op,
        input_shape=values.shape,
        output_shape=result_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=tensors_in,
        output_list=tensors_out,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    self.ser.addOperator(op["op"], tensors_in, tensors_out)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_scatter(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a SCATTER operator test case.

    Inputs are (values_in, indices, input).  NOTE: the local name
    ``input`` shadows the builtin within this method.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    assert len(inputs) == 3
    values_in, indices, input = inputs
    result_tensor = OutputShaper.scatterOp(
        self.ser, rng, values_in, indices, input, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indices.name, input.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_resize(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test case.

    The scale/offset/border values arrive both as shape-tensor inputs and
    as args_dict entries; the args_dict values drive shaping/validation,
    while the serialized ResizeAttribute is written with empty lists (the
    shape tensors carry the real values at run time).

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    assert len(inputs) == 4
    input = inputs[0]
    scale_input = inputs[1]
    offset_input = inputs[2]
    border_input = inputs[3]

    mode = args_dict["mode"]
    scale = args_dict["scale"]
    offset = args_dict["offset"]
    border = args_dict["border"]
    output_dtype = args_dict["output_dtype"]

    result_tensor = OutputShaper.resizeOp(
        self.ser,
        rng,
        input,
        mode,
        scale,
        offset,
        border,
        input.dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [
        input.name,
        scale_input.name,
        offset_input.name,
        border_input.name,
    ]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input.dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=result_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    # write empty scale/offset/border into ResizeAttribute
    attr.ResizeAttribute([], [], [], mode)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, input.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONST operator test case.

    The constant tensor is already created; simply mark it as a graph
    output and attach compliance metadata.

    Returns:
        TosaTestGen.BuildInfo for the constant tensor.
    """
    assert len(inputs) == 1
    const_tensor = inputs[0]

    # Expose the constant as a graph output.
    self.ser.addOutputTensor(const_tensor)

    meta = self.tensorComplianceMetaData(
        op, const_tensor.dtype, args_dict, const_tensor, error_name
    )
    return TosaTestGen.BuildInfo(const_tensor, meta)
# Type Conversion
def build_cast(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CAST operator test case.

    Converts the input tensor to args_dict["out_type"].

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    assert len(inputs) == 1
    val = inputs[0]
    out_dtype = args_dict["out_type"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, rng, val, out_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=result_tensor.shape,
        input_dtype=val.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_rescale(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test case.

    Chooses input/output zero points based on the dtypes (or forces
    non-zero/invalid ones for the relevant ERROR_IF tests), clamps the
    previously generated placeholder data so it satisfies the
    apply_scale_32 REQUIREs, then validates and serializes the operator.

    Args:
        inputs: (value, multiplier, shift) tensors
        args_dict: must contain "output_dtype", "scale", "double_round",
            "per_channel", "shift" and "multiplier"

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    assert len(inputs) == 3
    val = inputs[0]
    multiplier_val = inputs[1]
    shift_val = inputs[2]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]
    shift_arr = args_dict["shift"]
    multiplier_arr = args_dict["multiplier"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, rng, val, out_dtype, error_name
    )

    # Number of rescale channels: one per last-dim element when
    # per-channel, otherwise a single broadcast channel.
    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = gtu.dtypeWidth(val.dtype)
    out_type_width = gtu.dtypeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    # Pick the input zero point; the extra bit of width accounts for the
    # zero-point adjusted value range.
    if val.dtype == DType.INT8:
        input_zp = rng.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = rng.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # Force a non-zero zp to trigger the ERROR_IF condition
        input_zp = rng.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    # Pick the output zero point using the same scheme.
    if out_dtype == DType.INT8:
        output_zp = rng.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = rng.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = rng.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Per-channel bounds implied by each channel's shift value.
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    logger.debug(
        f"build_rescale: multiplier={multiplier_arr} shift={shift_arr} inzp={input_zp} outzp={output_zp}"
    )
    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values = np.load(
            os.path.join(self.basePath, self.testPath, val.placeholderFilename)
        )
        # Clamp the zero-point adjusted values into the per-channel range,
        # then shift back to the stored representation.
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check we can safely convert to the expected dtype.
        # BUG FIX: the previous check compared `val_adj.all()` (a bool)
        # against the dtype limits, which always passed; compare the
        # actual element-wise extremes instead.
        assert (
            val_adj.min() >= np.iinfo(values.dtype).min
            and val_adj.max() <= np.iinfo(values.dtype).max
        )
        # Force casting to output datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        # np.array_equal already returns a scalar bool
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(
                os.path.join(self.basePath, self.testPath, val.placeholderFilename),
                val_adj,
                False,
            )

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name, multiplier_val.name, shift_val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def _get_condition_tensor(self, rng, op, cond, error_name):
    """Create the condition const tensor for COND_IF tests.

    Uses a wrong dtype for CondIfCondNotMatchingBool and a non-size-1
    shape for CondIfCondShapeNotSizeOne; otherwise a rank-0 BOOL tensor.
    """
    # Dtype: deliberately wrong for the NotMatchingBool negative test.
    cond_type = (
        gtu.get_wrong_output_type(op, rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    # Shape: must be size 1 (rank 0) unless deliberately invalid.
    if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
        cond_shape = [2] if rng.choice([1, 2]) == 1 else [1, 2]
    else:
        cond_shape = []

    return self.ser.addConst(cond_shape, cond_type, [cond])
def build_cond_if_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose then/else blocks each emit a constant.

    The supplied then/else tensors are only used for their shape; fresh
    INT32 const data is generated for each branch.  For the output-list
    mismatch ERROR_IF tests one branch gets a deliberately wrong shape.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    # For cond_if with constants, we're supplied with then/else tensors that we ignore
    # (except for the generated shape) and the condition. Build Then/Else blocks
    # and fill them with const nodes for the body.
    assert len(inputs) == 2
    then_tens, else_tens = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    # Make then/else tensors
    out_shape = then_tens.shape

    dtype = DType.INT32

    # Create an incorrect output shape for error_if tests
    if error_name in [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]:
        incorrect_shape = deepcopy(then_tens.shape)
        for i in range(len(incorrect_shape)):
            # Perturb each dim; small dims grow so they stay positive
            incorrect_shape[i] += (
                rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else rng.choice([1, 2, 4])
            )
        incorrect_arr = np.int32(rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(rng.integers(0, 256, size=out_shape))

    # And the result tensor based on any of the outputs
    result_tensor = self.ser.addOutput(out_shape, dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Finally, build the op and the two blocks
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    self.ser.addBasicBlock(then_block)
    # Build the actual then/else tensors inside their blocks
    if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
        then_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
    else:
        then_tens = self.ser.addConst(out_shape, dtype, then_arr)
    self.ser.addOutputTensor(then_tens)

    self.ser.addBasicBlock(else_block)
    if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
        else_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
    else:
        else_tens = self.ser.addConst(out_shape, dtype, else_arr)
    self.ser.addOutputTensor(else_tens)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        op, dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_cond_if_binary(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test with a binary op in each branch.

    The then/else blocks apply complementary elementwise ops to (a, b):
    add/sub for FP/INT32 inputs, shifts for INT8/INT16.  The branch/shape
    mismatch ERROR_IF tests feed one branch a deliberately wrong-shaped
    input or output.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    # For cond_if with a binary op in the then/else blocks, take a and b and
    # alternately add or subtract them based on the condition
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    # Pick branch operations appropriate to the input dtype
    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = (
            self.TOSA_OP_LIST["logical_right_shift"],
            self.TOSA_OP_LIST["logical_left_shift"],
        )
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Determine the element-wise binary operation that compliance will need to
    # check the results of
    compliance_op = then_op if cond else else_op

    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_while_loop(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a WHILE_LOOP test that accumulates `a` for N iterations.

    Constructs the main region operator plus COND and BODY blocks:
    the COND block tests iter > 0, the BODY block adds `a` into the
    accumulator and decrements iter.  The graph-mismatch ERROR_IF tests
    substitute deliberately wrong-shaped iter/accumulator tensors.

    Returns:
        TosaTestGen.BuildInfo(acc_out, compliance), or None when the
        ERROR_IF validators reject this combination.
    """
    assert len(inputs) == 1
    a = inputs[0]
    iter_val = args_dict["iterations"]

    # Loop counter placeholder (rank 0, INT32)
    iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor
    # acc = self.ser.addOutput(a.shape, a.dtype)
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        # Wrong-shaped iter/acc used to provoke the graph mismatch errors
        incorrect_iter = deepcopy(iter)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            incorrect_iter.shape.append(rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += rng.choice([-3, -2, 2, 3])

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        # Rank 0 (size 1) as required by the spec
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    # Loop continues while iter > 0
    self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # acc += a ; iter -= 1
    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])

    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, acc_out, error_name
    )

    return TosaTestGen.BuildInfo(acc_out, compliance)
def build_fft2d(
self,
rng,
op,
inputs,
args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
assert len(inputs) == 2
val1, val2 = inputs
inverse = args_dict["inverse"]
results = OutputShaper.fft2dOp(self.ser, rng, val1, val2, error_name)
input_names = [val1.name, val2.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
output_names = [res.name for res in results]
output_shapes = [res.shape for res in results]
output_dtypes = [res.dtype for res in results]
input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
rng, error_name, input_names, output_names
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
inverse=inverse,
input1=val1,
input2=val2,
input_shape=val1.shape,
input_dtype=val1.dtype,
output_shape=output_shapes,
output_dtype=output_dtypes,
result_tensors=results,
input_list=input_names,
output_list=output_names,
num_operands=num_operands,
):
return None
# TODO - Test local_bound, for now set local bound attribute to False
local_bound = False
attr = ts.TosaSerializerAttribute()
attr.FFTAttribute(inverse, local_bound)
self.ser.addOperator(op["op"], input_names, output_names, attr)
compliance = []
for res in results:
compliance.append(
self.tensorComplianceMetaData(
op, val1.dtype, args_dict, res, error_name
)
)
return TosaTestGen.BuildInfo(results, compliance)
def build_rfft2d(
self,
rng,
op,
inputs,
args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
assert len(inputs) == 1
val = inputs[0]
results = OutputShaper.rfft2dOp(self.ser, rng, val, error_name)
input_names = [val.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
output_names = [res.name for res in results]
output_shapes = [res.shape for res in results]
output_dtypes = [res.dtype for res in results]
input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
rng, error_name, input_names, output_names
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=val.shape,
input_dtype=val.dtype,
output_shape=output_shapes,
output_dtype=output_dtypes,
result_tensors=results,
input_list=input_names,
output_list=output_names,
num_operands=num_operands,
):
return None
# TODO - Test local_bound, for now set local bound attribute to False
local_bound = False
attr = ts.TosaSerializerAttribute()
attr.RFFTAttribute(local_bound)
self.ser.addOperator(op["op"], input_names, output_names, attr)
compliance = []
for res in results:
compliance.append(
self.tensorComplianceMetaData(op, val.dtype, args_dict, res, error_name)
)
return TosaTestGen.BuildInfo(results, compliance)
def build_shape_op(
self,
rng,
op,
inputs,
args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
assert len(inputs) == 2
a, b = inputs
result_tensor = OutputShaper.addShapeOp(self.ser, rng, a, b, error_name)
# Invalidate Input/Output list for error if checks.
input_list = [a.name, b.name]
output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input1=a,
input2=b,
input_shape=a.shape,
input_dtype=a.dtype,
output_shape=result_tensor.shape,
output_dtype=result_tensor.dtype,
result_tensors=[result_tensor],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
self.ser.addOperator(
op["op"],
input_list,
output_list,
)
compliance = self.tensorComplianceMetaData(
op, a.dtype, args_dict, result_tensor, error_name
)
return TosaTestGen.BuildInfo(result_tensor, compliance)
def create_filter_lists(
self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
# Create a default testing rank range
if testType == "positive":
# 0-3 inclusive to keep test sizes reasonably small.
default_test_rank_range = range(0, 4)
else:
# Some errors do not work with rank 0, use 1-3
default_test_rank_range = range(1, 4)
# Calculate the filters based on what is requested and what the operator allows
rmin, rmax = op["rank"]
if shapeFilter:
# Specified shapes - ignore rank filter and default to op ranks below
rankFilter = None
ranksToCheck = []
elif rankFilter is None:
# No set rank filter so ensure default behaviour is bounded
ranksToCheck = default_test_rank_range
else:
ranksToCheck = rankFilter
cleanRankFilter = []
# Ensure rank values are allowed by operator
for rank in ranksToCheck:
if rank >= rmin and rank <= rmax:
cleanRankFilter.append(rank)
if shapeFilter or (len(cleanRankFilter) == 0 and rankFilter is None):
# Shapes specified or default test ranks didn't meet
# op requirements - so just use op ranks
cleanRankFilter = range(rmin, rmax + 1)
dtypes = op["types"]
if dtypeFilter is not None:
cleanDtypeFilter = []
# Create list of operator dtypes filtered by requested dtypes
for dtype in dtypes:
if dtype in dtypeFilter or (
isinstance(dtype, list) and dtype[0] in dtypeFilter
):
cleanDtypeFilter.append(dtype)
else:
cleanDtypeFilter = dtypes
if not shapeFilter:
shapeFilter = [None]
if testType == "positive":
filterDict = {
"shapeFilter": shapeFilter,
"rankFilter": cleanRankFilter,
"dtypeFilter": cleanDtypeFilter,
}
return filterDict
elif testType == "negative":
if validator is not None:
validator_info = validator(check=False, op=op)
else:
return None
error_arguments = validator_info["param_reqs"]
# Set parameters as required
if error_arguments["rank"] is not None:
rankFilter = error_arguments["rank"]
else:
rankFilter = cleanRankFilter
if error_arguments["dtype"] is not None:
dtypeFilter = error_arguments["dtype"]
else:
dtypeFilter = cleanDtypeFilter
if error_arguments["shape"] is not None:
shapeFilter = error_arguments["shape"]
else:
shapeFilter = shapeFilter[
:2
] # Reduce number of shapes to keep test numbers small
filterDict = {
"shapeFilter": shapeFilter,
"rankFilter": rankFilter,
"dtypeFilter": dtypeFilter,
}
return filterDict
    def genOpTestList(
        self,
        opName,
        # NOTE(review): mutable default is safe here - shapeFilter is only
        # rebound (never mutated) by this method and create_filter_lists
        shapeFilter=[None],
        rankFilter=None,
        dtypeFilter=None,
        testType="positive",
    ):
        """Generate the list of tests to build for an operator.

        For "positive" tests one test is created per valid combination of
        rank, dtype, shape and argument set. For "negative" tests the op's
        ERROR_IF validators drive the combinations instead.

        Returns a list of tuples:
            (opName, testNameStr, dtype, error_name, shapeList, argsDict)
        where error_name is None for positive tests.

        Raises an Exception if opName is unknown, or (for negative tests
        outside level8k mode) if any ERROR_IF type produced no tests.
        """
        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError:
            raise Exception("Cannot find op with name {}".format(opName))
        if not self.args.stable_rng:
            # Initialize a new random number generator per op
            self.resetGlobalRNG()
        # Shape generator and argument generator from the op's build tuple
        _, tgen_fcn, _, agen_fcn = op["build_fcn"]
        # Test list consists of a tuple of:
        # (opName, testNameStr, dtype, shapeList, argumentsList)
        testList = []
        if testType == "negative" and "error_if_validators" in op:
            error_if_validators = op["error_if_validators"]
            num_error_types_created = 0
        else:
            # Positive path: a single iteration with no error validator
            error_if_validators = [None]
            num_error_types_created = None
        for validator in error_if_validators:
            if validator is not None:
                error_name = validator(check=False, op=op)["error_name"]
            else:
                error_name = None
            filterDict = self.create_filter_lists(
                op, shapeFilter, rankFilter, dtypeFilter, testType, validator
            )
            if filterDict is None:
                return []
            cleanRankFilter = filterDict["rankFilter"]
            cleanDtypeFilter = filterDict["dtypeFilter"]
            cleanShapeFilter = filterDict["shapeFilter"]
            logger.debug(
                f"genOpTestList: Error={error_name}, Filters S={cleanShapeFilter}, R={cleanRankFilter}, T={cleanDtypeFilter}"
            )
            for r in cleanRankFilter:
                for t in cleanDtypeFilter:
                    for shape in cleanShapeFilter:
                        # Filter out by rank
                        if shape is not None and len(shape) != r:
                            continue
                        self.setTargetShape(shape)
                        typeStr = self.typeStr(t)
                        if self.args.stable_rng:
                            # Stable mode: RNG seeded from test identity so
                            # shapes are reproducible across runs
                            shape_rng = TosaHashRandomGenerator(
                                self.random_seed,
                                [opName, r, typeStr],
                                self.random_dtype_range,
                            )
                        else:
                            shape_rng = self.global_rng
                        shapeList = tgen_fcn(self, shape_rng, op, r, error_name)
                        shapeStr = self.shapeStr(shapeList[0])
                        # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
                        argList = []
                        if agen_fcn:
                            if self.args.stable_rng:
                                # Separate RNG stream for argument generation
                                arg_rng = TosaHashRandomGenerator(
                                    self.random_seed,
                                    [opName, shapeStr, typeStr],
                                    self.random_dtype_range,
                                )
                            else:
                                arg_rng = self.global_rng
                            argList = agen_fcn(
                                self, arg_rng, opName, shapeList, t, error_name
                            )
                        else:
                            argList = [("", [])]
                        for argStr, args in argList:
                            # Create the test name string - for example: add_1x2x3_i32
                            if testType == "positive":
                                name_parts = [opName, shapeStr, typeStr]
                            else:
                                assert testType == "negative"
                                name_parts = [
                                    opName,
                                    "ERRORIF",
                                    error_name,
                                    shapeStr,
                                    typeStr,
                                ]
                            if argStr:
                                name_parts.append(argStr)
                            testStr = "_".join(name_parts)
                            testList.append(
                                (opName, testStr, t, error_name, shapeList, args)
                            )
            if error_name is not None:
                # Check the last test is of the error we wanted
                if len(testList) == 0 or testList[-1][3] != error_name:
                    if self.args.level8k:
                        logger.info(f"Missing {error_name} tests due to level8k mode")
                    else:
                        logger.error(f"ERROR: Failed to create any {error_name} tests")
                        logger.debug(
                            "Last test created: {}".format(
                                testList[-1] if testList else None
                            )
                        )
                else:
                    # Successfully created at least one ERROR_IF test
                    num_error_types_created += 1
        if testType == "positive":
            # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
            if "invalid_test_validators" in op:
                invalid_test_validators = op["invalid_test_validators"]
                clean_testList = []
                for test in testList:
                    remove_test = False
                    for validator_fcn in invalid_test_validators:
                        if validator_fcn(
                            opName=test[0],
                            input_dtype=test[2],
                            shapeList=test[4],
                            args=test[5],
                        ):
                            remove_test = True
                    if not remove_test:
                        clean_testList.append(test)
                testList = clean_testList
        else:
            # Negative path: fail loudly if any requested ERROR_IF type
            # produced no tests (level8k mode legitimately skips some)
            if num_error_types_created is not None and not self.args.level8k:
                remaining_error_types = (
                    len(error_if_validators) - num_error_types_created
                )
                if remaining_error_types:
                    raise Exception(
                        f"Failed to create {remaining_error_types} error types for {opName}"
                    )
        return testList
    def serializeTest(
        self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
    ):
        """Build one test (from a genOpTestList entry) and serialize it.

        Generates tensor values, runs the op's build function (which also
        performs ERROR_IF validation) and, if the test is valid, serializes
        it with its compliance/data-gen meta data.

        Returns True if the test was serialized, False if the build function
        rejected it (invalid ERROR_IF test).
        """
        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError:
            raise Exception("Cannot find op with name {}".format(opName))
        logger.info(f"Creating {testStr}")
        # Create a serializer
        self.createSerializer(opName, testStr)
        # Build function and tensor-value generator from the op's build tuple
        build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
        if "error_if_validators" in op:
            error_if_validators = op["error_if_validators"]
        else:
            error_if_validators = None
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        # Expand a single dtype into a per-operand list; CONCAT takes a
        # variable number of inputs so it is sized by shapeList instead
        if isinstance(dtype_or_dtypeList, list):
            dtypeList = dtype_or_dtypeList
        elif op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE):
            dtypeList = [dtype_or_dtypeList] * len(shapeList)
        else:
            dtypeList = [dtype_or_dtypeList] * (num_operands)
        if op["op"] not in (Op.CONCAT, Op.CONCAT_SHAPE):
            assert (
                len(shapeList) == num_operands
            ), "shapeList length {} must match number of operands {}".format(
                len(shapeList), num_operands
            )
            assert (
                len(dtypeList) == num_operands
            ), "dtypeList length {} must match number of operands {}".format(
                len(dtypeList), num_operands
            )
        # Optional quantization info generator
        try:
            qgen = op["qgen"]
        except KeyError:
            qgen = None
        # Build the random tensor operands and the test
        # Set the random number generator
        if self.args.stable_rng:
            # Stable mode: RNG seeded from the test name for reproducibility
            build_rng = TosaHashRandomGenerator(
                self.random_seed, [testStr], self.random_dtype_range
            )
        else:
            build_rng = self.global_rng
        if qgen is not None:
            qinfo = qgen(
                build_rng, self.args.zeropoint, op, dtype_or_dtypeList, error_name
            )
        else:
            qinfo = None
        # Extra meta data for the desc.json
        tensMeta = {}
        # Check we are using the new interface with an argsDict dictionary
        assert isinstance(
            argsDict, dict
        ), f"{opName} is not using new tvg/build_fcn interface"
        # New interface with args info in dictionary
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(
            self, build_rng, opName, dtypeList, shapeList, argsDict, error_name
        )
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList
        tags = argsDict.get("tags", None)
        result = build_fcn(
            self,
            build_rng,
            op,
            tens,
            argsDict,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
        if result:
            # The test is valid, serialize it
            if isinstance(result, TosaTestGen.BuildInfo):
                # Add the compliance meta data (if any)
                compliance = result.getComplianceInfo()
                if compliance:
                    tensMeta["compliance"] = compliance
            self.serialize("test", tensMeta, tags)
            return True
        else:
            # The test is not valid
            logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
            return False
    def createDynamicOpLists(self):
        """Expand templated conv ops into per-kernel-size concrete op entries.

        Each "*_TEMPLATE" entry in TOSA_OP_LIST is replaced by one entry per
        kernel size (named e.g. "conv2d_3x3"); the template entry itself is
        then deleted. NOTE: this mutates TOSA_OP_LIST, which is a class-level
        attribute, so the expansion is shared by all instances.
        """
        # Find all the ops marked as templates
        templateKeys = []
        for opName in self.TOSA_OP_LIST:
            try:
                if self.TOSA_OP_LIST[opName]["template"]:
                    templateKeys.append(opName)
            except KeyError:
                pass
        bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
        # Add dynamic ops based on kernel sizes
        for opName in templateKeys:
            assert opName.endswith("_TEMPLATE"), "Found incorrect template"
            # Strip the "_TEMPLATE" suffix to get the real op name
            realName = opName[: len(opName) - len("_TEMPLATE")]
            template = self.TOSA_OP_LIST[opName]
            # conv3d kernels are rank 3 (DHW); all other convs are rank 2 (HW)
            k_rank = 3 if realName == "conv3d" else 2
            # Choose kernels to build tests for from the template or args
            if self.args.level8k:
                # Level 8K testing: use maximum-sized kernels
                if k_rank == 3:
                    kernels = [[1, bigK, 1], [2, 2, bigK]]
                else:
                    kernels = [[1, bigK], [bigK, 2]]
            else:
                kernels = []
                if len(self.args.conv_kernels) > 0:
                    # Only keep user-requested kernels of the right rank
                    kernels = [k for k in self.args.conv_kernels if len(k) == k_rank]
                    if len(kernels) == 0:
                        logger.debug(
                            f"{realName} op using defaults as no rank {k_rank} kernels found in {self.args.conv_kernels}"
                        )
                if len(kernels) == 0:
                    # Fallback to use the defined template kernels
                    kernels = self.TOSA_OP_LIST[opName]["filter"]
            # Dynamically create ops for listed kernel sizes
            for k in kernels:
                kernelStr = "x".join([str(d) for d in k])
                testName = f"{realName}_{kernelStr}"
                # Shallow copy is sufficient: only top-level keys are changed
                kernelOp = template.copy()
                kernelOp["filter"] = k
                kernelOp["template"] = False
                kernelOp["real_name"] = realName
                self.TOSA_OP_LIST[testName] = kernelOp
            # Delete the template after having created the dynamic ops
            del self.TOSA_OP_LIST[opName]
def initOpListDefaults(self):
"""Fill in default fields for ops if they aren't already specified.
Look for missing required fields (datastructure linting)."""
for op in self.TOSA_OP_LIST:
# Required fields
try:
pl, c = self.TOSA_OP_LIST[op]["operands"]
except (KeyError, ValueError, TypeError):
raise Exception(
"Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
)
try:
fcn, tgen, tvgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
except (KeyError, ValueError, TypeError):
raise Exception(
"Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
op
)
)
try:
_ = self.TOSA_OP_LIST[op]["types"]
except KeyError:
raise Exception(
"Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
)
try:
_ = self.TOSA_OP_LIST[op]["op"]
except KeyError:
raise Exception(
"Op {} is missing the Op field in TOSA_OP_LIST".format(op)
)
# Put in default rank range, if missing
try:
_ = self.TOSA_OP_LIST[op]["rank"]
except KeyError:
self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
# Tensor operator list
# 'op': op name
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
# if not specified, defaults to (1, 4)
# 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
# 'types': array of datatypes to be tested
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
TYPE_INT_FP = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.FP16,
DType.BF16,
DType.FP32,
] # Excludes INT4
TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
DType.FP32,
DType.FP16,
DType.BF16,
DType.INT32,
] # floating-types and INT32
TYPE_FIB = [
DType.FP16,
DType.BF16,
DType.FP32,
DType.INT8,
DType.INT16,
DType.INT32,
DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]
TYPE_NARROW_INT_FP = [
DType.INT8,
DType.INT16,
DType.FP16,
DType.BF16,
DType.FP32,
]
# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
[DType.INT8, DType.INT4, DType.INT32],
[DType.INT8, DType.INT8, DType.INT32],
[DType.INT16, DType.INT8, DType.INT48],
[DType.FP16, DType.FP16, DType.FP16],
[DType.FP16, DType.FP16, DType.FP32],
[DType.BF16, DType.BF16, DType.FP32],
[DType.FP32, DType.FP32, DType.FP32],
[DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
[DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]
DEFAULT_RANK_RANGE = (0, gtu.MAX_TENSOR_RANK)
KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
PSEUDO_RANDOM_DATAGEN = {
DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM,),
DType.FP32: (gtu.DataGenType.PSEUDO_RANDOM,),
}
DOT_PRODUCT_DATAGEN = {
DType.FP16: (gtu.DataGenType.DOT_PRODUCT,),
DType.FP32: (gtu.DataGenType.DOT_PRODUCT,),
}
EW_UNARY_DATAGEN = {
DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FULL_RANGE),
}
PR_FS_DATAGEN = {
DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FP_SPECIAL),
DType.FP32: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FP_SPECIAL),
}
TOSA_OP_LIST = {
# Tensor operators
"argmax": {
"op": Op.ARGMAX,
"operands": (1, 0),
"rank": (1, gtu.MAX_TENSOR_RANK),
"build_fcn": (
build_argmax,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evArgmaxOutputRankMismatch,
TosaErrorValidator.evArgmaxOutputShapeMismatch,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"avg_pool2d": {
"op": Op.AVG_POOL2D,
"operands": (1, 0),
"rank": (4, 4),
"build_fcn": (
build_pool2d,
TosaTensorGen.tgNHWC,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agPooling,
),
"qgen": TosaQuantGen.qgUnary,
"types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evKernelSmallerOne,
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evOutputZeroPointNotZero,
TosaErrorValidator.evPadLargerEqualKernel,
TosaErrorValidator.evPoolingOutputShapeMismatch,
TosaErrorValidator.evPoolingOutputShapeNonInteger,
TosaErrorValidator.evWrongAccumulatorType,
),
"data_gen": DOT_PRODUCT_DATAGEN,
},
# Templated operator. Filled in by createDynamicOpLists
"conv2d_TEMPLATE": {
"op": Op.CONV2D,
"operands": (1, 2),
"rank": (4, 4),
"build_fcn": (
build_conv2d,
TosaTensorGen.tgConv2D,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWeightZeroPointNotZero,
TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evDilationSmallerOne,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
TosaErrorValidator.evWrongAccumulatorType,
),
"data_gen": DOT_PRODUCT_DATAGEN,
"broadcastable_bias": True,
"filter": KERNELS_2D,
"template": True,
},
# Templated operator. Filled in by createDynamicOpLists
"conv3d_TEMPLATE": {
"op": Op.CONV3D,
"operands": (1, 2),
"rank": (5, 5),
"build_fcn": (
build_conv3d,
TosaTensorGen.tgConv3D,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWeightZeroPointNotZero,
TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evDilationSmallerOne,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
TosaErrorValidator.evWrongAccumulatorType,
),
"data_gen": DOT_PRODUCT_DATAGEN,
"filter": KERNELS_3D,
"template": True,
},
# Templated operator. Filled in by createDynamicOpLists
"depthwise_conv2d_TEMPLATE": {
"op": Op.DEPTHWISE_CONV2D,
"operands": (1, 2),
"rank": (4, 4),
"build_fcn": (
build_depthwise_conv2d,
TosaTensorGen.tgDepthwiseConv2D,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWeightZeroPointNotZero,
TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evDilationSmallerOne,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
TosaErrorValidator.evWrongAccumulatorType,
),
"data_gen": DOT_PRODUCT_DATAGEN,
"filter": KERNELS_2D,
"template": True,
},
"fully_connected": {
"op": Op.FULLY_CONNECTED,
"operands": (1, 2),
"rank": (2, 2),
"build_fcn": (
build_fully_connected,
TosaTensorGen.tgFullyConnected,
TosaTensorValuesGen.tvgFullyConnected,
TosaArgGen.agFullyConnected,
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWeightZeroPointNotZero,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": DOT_PRODUCT_DATAGEN,
},
"matmul": {
"op": Op.MATMUL,
"operands": (2, 0),
"rank": (3, 3),
"build_fcn": (
build_matmul,
TosaTensorGen.tgMatmul,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agMatMul,
),
"qgen": TosaQuantGen.qgMatmul,
"types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": DOT_PRODUCT_DATAGEN,
},
"max_pool2d": {
"op": Op.MAX_POOL2D,
"operands": (1, 0),
"rank": (4, 4),
"build_fcn": (
build_pool2d,
TosaTensorGen.tgNHWC,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agPooling,
),
"types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evKernelSmallerOne,
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evPadLargerEqualKernel,
TosaErrorValidator.evPoolingOutputShapeMismatch,
TosaErrorValidator.evPoolingOutputShapeNonInteger,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
# Templated operator. Filled in by createDynamicOpLists
"transpose_conv2d_TEMPLATE": {
"op": Op.TRANSPOSE_CONV2D,
"operands": (1, 2),
"rank": (4, 4),
"build_fcn": (
build_transpose_conv2d,
TosaTensorGen.tgTransposeConv2D,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agTransposeConv2D,
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
"invalid_test_validators": (
TosaInvalidValidator.ivHeightWidthInvalid,
TosaInvalidValidator.ivNonPositiveOutputShape,
),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWeightZeroPointNotZero,
TosaErrorValidator.evPadLargerEqualKernel,
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evWrongAccumulatorType,
),
"data_gen": DOT_PRODUCT_DATAGEN,
"filter": KERNELS_2D,
"template": True,
},
# Activation functions
"clamp": {
"op": Op.CLAMP,
"operands": (1, 0),
"build_fcn": (
build_clamp,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_NARROW_INT_FP,
"error_if_validators": (
TosaErrorValidator.evMaxSmallerMin,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"sigmoid": {
"op": Op.SIGMOID,
"operands": (1, 0),
"build_fcn": (
build_activation,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"tanh": {
"op": Op.TANH,
"operands": (1, 0),
"build_fcn": (
build_activation,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {
"abs_error_lower_bound": 0.5,
},
},
"erf": {
"op": Op.ERF,
"operands": (1, 0),
"build_fcn": (
build_activation,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {"ulp": 5},
},
# Elementwise Binary Operators
"add": {
"op": Op.ADD,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgAddSub,
TosaArgGen.agNone,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
"data_gen": PR_FS_DATAGEN,
"compliance": {"ulp": 0.5},
},
"arithmetic_right_shift": {
"op": Op.ARITHMETIC_RIGHT_SHIFT,
"operands": (2, 0),
"build_fcn": (
build_arithmetic_right_shift,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgArithmeticRightShift,
TosaArgGen.agArithmeticRightShift,
),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_and": {
"op": Op.BITWISE_AND,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_or": {
"op": Op.BITWISE_OR,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_xor": {
"op": Op.BITWISE_XOR,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"intdiv": {
"op": Op.INTDIV,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgIntDiv,
TosaArgGen.agNone,
),
"types": [DType.INT32],
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_and": {
"op": Op.LOGICAL_AND,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_left_shift": {
"op": Op.LOGICAL_LEFT_SHIFT,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgLogicalShift,
TosaArgGen.agNone,
),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_right_shift": {
"op": Op.LOGICAL_RIGHT_SHIFT,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgLogicalShift,
TosaArgGen.agNone,
),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_or": {
"op": Op.LOGICAL_OR,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_xor": {
"op": Op.LOGICAL_XOR,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"maximum": {
"op": Op.MAXIMUM,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
"data_gen": PR_FS_DATAGEN,
},
"minimum": {
"op": Op.MINIMUM,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"mul": {
"op": Op.MUL,
"operands": (3, 0),
"build_fcn": (
build_mul,
TosaTensorGen.tgMul,
TosaTensorValuesGen.tvgMul,
TosaArgGen.agMul,
),
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {"ulp": 0.5},
},
"pow": {
"op": Op.POW,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgPow,
TosaArgGen.agPow,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"sub": {
"op": Op.SUB,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgAddSub,
TosaArgGen.agNone,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {"ulp": 0.5},
},
"table": {
"op": Op.TABLE,
# Use the automatic generation functions to create the input array
# but create the table tensor in the build function, as it may be
# a different type from the input
"operands": (1, 0),
"build_fcn": (
build_table,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agTable,
),
"types": [DType.INT8, DType.INT16],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
# Elementwise Unary operators
"abs": {
"op": Op.ABS,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": EW_UNARY_DATAGEN,
},
"bitwise_not": {
"op": Op.BITWISE_NOT,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"ceil": {
"op": Op.CEIL,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": EW_UNARY_DATAGEN,
"compliance": {"ulp": 0.5},
},
"clz": {
"op": Op.CLZ,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": [DType.INT32],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"cos": {
"op": Op.COS,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {
"abs_error_normal_divisor": 2,
"abs_error_bound_addition": 1,
},
},
"exp": {
"op": Op.EXP,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgExp,
TosaArgGen.agNone,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": EW_UNARY_DATAGEN,
},
"floor": {
"op": Op.FLOOR,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": EW_UNARY_DATAGEN,
"compliance": {"ulp": 0.5},
},
"log": {
"op": Op.LOG,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLogRsqrt,
TosaArgGen.agNone,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": EW_UNARY_DATAGEN,
"compliance": {"ulp": 5},
},
"logical_not": {
"op": Op.LOGICAL_NOT,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"negate": {
"op": Op.NEGATE,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgNegate,
TosaArgGen.agNone,
),
"qgen": TosaQuantGen.qgUnary,
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evOutputZeroPointNotZero,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": EW_UNARY_DATAGEN,
},
"reciprocal": {
"op": Op.RECIPROCAL,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": EW_UNARY_DATAGEN,
"compliance": {"ulp": 1.0},
},
"rsqrt": {
"op": Op.RSQRT,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLogRsqrt,
TosaArgGen.agNone,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": EW_UNARY_DATAGEN,
"compliance": {"ulp": 2},
},
"sin": {
"op": Op.SIN,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {"abs_error_normal_divisor": 2},
},
# Elementwise Ternary operators
"select": {
"op": Op.SELECT,
"operands": (3, 0),
"build_fcn": (
build_select,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgSelect,
TosaArgGen.agNone,
),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
# Comparison operators
"equal": {
"op": Op.EQUAL,
"operands": (2, 0),
"build_fcn": (
build_comparison,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgEqual,
TosaArgGen.agNone,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
"data_gen": PR_FS_DATAGEN,
},
"greater_equal": {
"op": Op.GREATER_EQUAL,
"operands": (2, 0),
"build_fcn": (
build_comparison,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
"data_gen": PR_FS_DATAGEN,
},
"greater": {
"op": Op.GREATER,
"operands": (2, 0),
"build_fcn": (
build_comparison,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
"data_gen": PR_FS_DATAGEN,
},
# Reduction operators
"reduce_all": {
"op": Op.REDUCE_ALL,
"operands": (1, 0),
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evShapeOfAxisNotOne,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"reduce_any": {
"op": Op.REDUCE_ANY,
"operands": (1, 0),
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evShapeOfAxisNotOne,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"reduce_max": {
"op": Op.REDUCE_MAX,
"operands": (1, 0),
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evShapeOfAxisNotOne,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"reduce_min": {
"op": Op.REDUCE_MIN,
"operands": (1, 0),
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evShapeOfAxisNotOne,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"reduce_product": {
"op": Op.REDUCE_PRODUCT,
"operands": (1, 0),
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgReduceProduct,
TosaArgGen.agAxis,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evShapeOfAxisNotOne,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"reduce_sum": {
"op": Op.REDUCE_SUM,
"operands": (1, 0),
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgReduceSum,
TosaArgGen.agAxis,
),
"types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evShapeOfAxisNotOne,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": DOT_PRODUCT_DATAGEN,
},
# Data layout operators
"concat": {
"op": Op.CONCAT,
"operands": (2, 0),
"build_fcn": (
build_concat,
TosaTensorGen.tgConcat,
TosaTensorValuesGen.tvgConcat,
TosaArgGen.agAxis,
),
"types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evConcatInputRankMismatch,
TosaErrorValidator.evConcatShapeSumMismatch,
TosaErrorValidator.evConcatInputDimMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"pad": {
"op": Op.PAD,
"operands": (2, 0),
"rank": (1, gtu.MAX_TENSOR_RANK),
"build_fcn": (
build_pad,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgPad,
TosaArgGen.agPad,
),
"types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evPadOutputShapeMismatch,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongRank,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"dim": {
"op": Op.DIM,
"operands": (1, 0),
"rank": (1, gtu.MAX_TENSOR_RANK),
"build_fcn": (
build_dim,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evWrongRank,
),
},
"reshape": {
"op": Op.RESHAPE,
"operands": (2, 0),
"rank": (1, gtu.MAX_TENSOR_RANK),
"build_fcn": (
build_reshape,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgReshape,
TosaArgGen.agReshape,
),
"types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evTensorSizeInputOutputMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"reverse": {
"op": Op.REVERSE,
"operands": (1, 0),
"build_fcn": (
build_reverse,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"slice": {
"op": Op.SLICE,
"operands": (3, 0),
"rank": (1, gtu.MAX_TENSOR_RANK),
"build_fcn": (
build_slice,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgSlice,
TosaArgGen.agSlice,
),
"types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
# TODO Turn off these error categories for now as the reference
# model cannot allocate memory space for an empty tensor. We probably
# can report an accurate error message at the right place during
# execution.
# TosaErrorValidator.evStartSmallerZero,
# TosaErrorValidator.evSizeSmallerEqualZero,
TosaErrorValidator.evStartSizeOutsideBounds,
TosaErrorValidator.evSizeOutputShapeMismatch,
TosaErrorValidator.evInputSizeStartLengthMismatch,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evRankMismatch,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"tile": {
"op": Op.TILE,
"operands": (2, 0),
"rank": (1, gtu.MAX_TENSOR_RANK),
"build_fcn": (
build_tile,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgTile,
TosaArgGen.agTile,
),
"types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongRank,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"transpose": {
"op": Op.TRANSPOSE,
"operands": (1, 0),
"rank": (1, gtu.MAX_TENSOR_RANK),
"build_fcn": (
build_transpose,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agTranspose,
),
"types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evIndexOutsideBounds,
TosaErrorValidator.evIndexUsedTwice,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evTensorSizeInputOutputMismatch,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
# Data nodes
"const": {
"op": Op.CONST,
"operands": (0, 1),
"build_fcn": (
build_const,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"identity": {
"op": Op.IDENTITY,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": TYPE_FIB + [DType.INT4, DType.INT48],
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
# Scatter/Gather
"gather": {
"op": Op.GATHER,
"operands": (2, 0),
"rank": (3, 3),
"build_fcn": (
build_gather,
TosaTensorGen.tgGather,
TosaTensorValuesGen.tvgGather,
TosaArgGen.agNone,
),
"types": (
DType.INT8,
DType.INT16,
DType.INT32,
DType.FP16,
DType.BF16,
DType.FP32,
DType.FP8E4M3,
DType.FP8E5M2,
),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evWrongRank,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
"scatter": {
"op": Op.SCATTER,
"operands": (3, 0),
"rank": (3, 3),
"build_fcn": (
build_scatter,
TosaTensorGen.tgScatter,
TosaTensorValuesGen.tvgScatter,
TosaArgGen.agNone,
),
"types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evWrongRank,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
},
# Image operations
"resize": {
"op": Op.RESIZE,
"operands": (4, 0),
"rank": (4, 4),
"build_fcn": (
build_resize,
TosaTensorGen.tgNHWC,
TosaTensorValuesGen.tvgResize,
TosaArgGen.agResize,
),
"types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
"invalid_test_validators": (
TosaInvalidValidator.ivWrongDataTypeOrModeResize,
),
"error_if_validators": (
TosaErrorValidator.evMaxDimExceeded,
TosaErrorValidator.evScaleSmallerEqualZero,
TosaErrorValidator.evScaleNLargerMax,
TosaErrorValidator.evScaleDLargerMax,
TosaErrorValidator.evOffsetSmallerMin,
TosaErrorValidator.evOffsetLargerEqualMax,
TosaErrorValidator.evBorderSmallerMin,
TosaErrorValidator.evBorderLargerEqualMax,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evBatchMismatch,
TosaErrorValidator.evChannelMismatch,
TosaErrorValidator.evResizeOutputShapeMismatch,
TosaErrorValidator.evResizeOutputShapeNonInteger,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {"relative": 0.006},
},
# Type conversion
"cast": {
"op": Op.CAST,
"operands": (1, 0),
"build_fcn": (
build_cast,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgCast,
TosaArgGen.agCast,
),
"types": (
DType.FP16,
DType.BF16,
DType.FP32,
DType.INT8,
DType.INT16,
DType.INT32,
DType.BOOL,
DType.FP8E4M3,
DType.FP8E5M2,
),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
"data_gen": PSEUDO_RANDOM_DATAGEN,
"compliance": {"ulp": 0.5},
},
"rescale": {
"op": Op.RESCALE,
"operands": (3, 0),
"build_fcn": (
build_rescale,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgRescale,
TosaArgGen.agRescale,
),
"types": [
DType.UINT8,
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.UINT16,
],
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evOutputZeroPointNotZero,
TosaErrorValidator.evU16InputZeroPointNotValid,
TosaErrorValidator.evU16OutputZeroPointNotValid,
TosaErrorValidator.evScaleTrue,
TosaErrorValidator.evScaleNotTrue,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
# Custom
# Not implemented.
# Control flow operators
# Two variants of cond_if, one that generates one of two constant tensors (no
# inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
# (two inputs to the basic blocks, one output)
"cond_if_const": {
"op": Op.COND_IF,
"operands": (0, 2),
"build_fcn": (
build_cond_if_const,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgCondIfWhileLoop,
TosaArgGen.agCondIf,
),
"types": [DType.BOOL],
"error_if_validators": (
TosaErrorValidator.evOutputListThenGraphMismatch,
TosaErrorValidator.evOutputListElseGraphMismatch,
TosaErrorValidator.evCondIfCondNotMatchingBool,
TosaErrorValidator.evCondIfCondShapeNotSizeOne,
),
},
"cond_if_binary": {
"op": Op.COND_IF,
"operands": (2, 0),
"build_fcn": (
build_cond_if_binary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgCondIfWhileLoop,
TosaArgGen.agCondIf,
),
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evInputListThenGraphMismatch,
TosaErrorValidator.evInputListElseGraphMismatch,
TosaErrorValidator.evOutputListThenGraphMismatch,
TosaErrorValidator.evOutputListElseGraphMismatch,
TosaErrorValidator.evCondIfCondNotMatchingBool,
TosaErrorValidator.evCondIfCondShapeNotSizeOne,
),
},
# while_loop
"while_loop": {
"op": Op.WHILE_LOOP,
"operands": (0, 1),
"build_fcn": (
build_while_loop,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgCondIfWhileLoop,
TosaArgGen.agWhileLoop,
),
"types": [DType.INT32],
"error_if_validators": (
TosaErrorValidator.evInputListOutputListMismatch,
TosaErrorValidator.evInputListCondGraphMismatch,
TosaErrorValidator.evInputListBodyGraphInputMismatch,
TosaErrorValidator.evInputListBodyGraphOutputMismatch,
TosaErrorValidator.evCondGraphOutputNotMatchingBool,
TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
),
},
"fft2d": {
"op": Op.FFT2D,
"operands": (2, 0),
"rank": (3, 3),
"build_fcn": (
build_fft2d,
TosaTensorGen.tgFFT2d,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agFFT2d,
),
"types": [DType.FP32],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evBatchMismatch,
TosaErrorValidator.evKernelNotPowerOfTwo,
TosaErrorValidator.evFFTInputShapeMismatch,
TosaErrorValidator.evFFTOutputShapeMismatch,
),
"data_gen": DOT_PRODUCT_DATAGEN,
},
"rfft2d": {
"op": Op.RFFT2D,
"operands": (1, 0),
"rank": (3, 3),
"build_fcn": (
build_rfft2d,
TosaTensorGen.tgRFFT2d,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agRFFT2d,
),
"types": [DType.FP32],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evBatchMismatch,
TosaErrorValidator.evKernelNotPowerOfTwo,
TosaErrorValidator.evFFTOutputShapeMismatch,
),
"data_gen": DOT_PRODUCT_DATAGEN,
},
# Shape
"add_shape": {
"op": Op.ADD_SHAPE,
"operands": (2, 0),
"rank": (1, 1),
"build_fcn": (
build_shape_op,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgAddSub,
TosaArgGen.agNone,
),
"types": [DType.SHAPE],
"error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
},
"sub_shape": {
"op": Op.SUB_SHAPE,
"operands": (2, 0),
"rank": (1, 1),
"build_fcn": (
build_shape_op,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgAddSub,
TosaArgGen.agNone,
),
"types": [DType.SHAPE],
"error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
},
"mul_shape": {
"op": Op.MUL_SHAPE,
"operands": (2, 0),
"rank": (1, 1),
"build_fcn": (
build_shape_op,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgMul,
TosaArgGen.agNone,
),
"types": [DType.SHAPE],
"error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
},
"div_shape": {
"op": Op.DIV_SHAPE,
"operands": (2, 0),
"rank": (1, 1),
"build_fcn": (
build_shape_op,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgIntDiv,
TosaArgGen.agNone,
),
"types": [DType.SHAPE],
"error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
},
"concat_shape": {
"op": Op.CONCAT_SHAPE,
"operands": (2, 0),
"rank": (1, 1),
"build_fcn": (
build_concat,
TosaTensorGen.tgConcat,
TosaTensorValuesGen.tvgConcat,
TosaArgGen.agNone,
),
"types": [DType.SHAPE],
"error_if_validators": (),
},
"const_shape": {
"op": Op.CONST_SHAPE,
"operands": (0, 1),
"rank": (1, 1),
"build_fcn": (
build_const,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
"types": [DType.SHAPE],
},
}
class OutputShaper:
# Methods in this class compute the expected output shape and datatype
# for common classes of operations
def __init__(self):
    """No per-instance state; all shape/dtype helpers below are static."""
    pass
# These methods return arguments that can be used for
# creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a broadcastable binary op to `ser`.

    Broadcasts each size-1 dimension of `a` against `b` for valid tests;
    for ERRORIF tests the shape or dtype is deliberately corrupted.
    Returns the serializer's new output tensor.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast size-1 dims of `a` against `b` (only when not an ERRORIF test).
    out_shape = [
        b.shape[idx] if dim == 1 and error_name is None else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.DimensionMismatch and out_shape:
        # Dimension-mismatch error is only creatable for rank > 0.
        bad_axis = rng.integers(0, len(out_shape))
        out_shape[bad_axis] += 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        # Pick any usable dtype other than the input's.
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        } - {a.dtype}
        out_dtype = rng.choice(list(candidates))

    return ser.addOutput(out_shape, out_dtype)
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
assert len(a.shape) == len(b.shape)
assert a.dtype == b.dtype
shape = []
for i in range(len(a.shape)):
assert a.shape[i] == b.shape[i]
shape.append(a.shape[i])
return ser.addOutput(shape, a.dtype)
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an elementwise unary op to `ser`.

    Output shape always matches the input; the dtype matches too unless a
    WrongOutputType ERRORIF test deliberately picks a different one.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        # Deliberately mis-type the output for the ERRORIF test.
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        } - {a.dtype}
        out_dtype = rng.choice(list(candidates))
    return ser.addOutput(a.shape, out_dtype)
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for SELECT (cond ? a : b) to `ser`.

    Broadcasts size-1 dims of the condition against the value operands for
    valid tests; corrupts shape/dtype for the relevant ERRORIF tests.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, dim in enumerate(cond.shape):
        if dim == 1 and error_name is None:
            # Broadcast the condition against the value operands.
            out_shape.append(max(dim, a.shape[idx], b.shape[idx]))
        else:
            out_shape.append(dim)

    if error_name == ErrorIf.DimensionMismatch and out_shape:
        # Dimension-mismatch error is only creatable for rank > 0.
        bad_axis = rng.integers(0, len(out_shape))
        out_shape[bad_axis] += 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        } - {a.dtype}
        out_dtype = rng.choice(list(candidates))
    return ser.addOutput(out_shape, out_dtype)
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the (boolean) output tensor for a comparison op to `ser`.

    Broadcasts size-1 dims of `a` against `b`; output dtype is BOOL unless a
    WrongOutputType ERRORIF test deliberately picks a non-boolean dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast size-1 dims of `a` against `b`, guarding the index because
    # `b` may have a smaller rank in RankMismatch ERRORIF tests.
    out_shape = [
        b.shape[idx] if dim == 1 and idx < len(b.shape) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.DimensionMismatch and out_shape:
        # Dimension-mismatch error is only creatable for rank > 0.
        bad_axis = rng.integers(0, len(out_shape))
        out_shape[bad_axis] += 1

    # Comparisons produce booleans unless deliberately mis-typed.
    out_dtype = DType.BOOL
    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
    return ser.addOutput(out_shape, out_dtype)
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction along `axis` to `ser`.

    A valid reduction collapses the chosen axis to size one; the axis and
    shape ERRORIF cases deliberately leave or force an invalid size.
    """
    out_shape = a.shape.copy()
    axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_errors:
        # Reduction collapses the chosen axis to size one.
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        # Force an invalid (non-one) size on the reduced axis.
        out_shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        } - {a.dtype}
        out_dtype = rng.choice(list(candidates))
    return ser.addOutput(out_shape, out_dtype)
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for ARGMAX along `axis` to `ser`.

    A valid ARGMAX removes the reduced axis and returns INT32 indices; the
    rank/shape/dtype ERRORIF cases corrupt the output accordingly.
    """
    out_shape = a.shape.copy()
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        # ARGMAX removes the reduced axis from the output shape.
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Either drop a leading dim or append one to break the rank.
        drop_dim = rng.choice([True, False])
        if drop_dim and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        # Grow every dimension so the output shape cannot match.
        for idx in range(len(out_shape)):
            out_shape[idx] = out_shape[idx] + rng.integers(1, 10)

    # ARGMAX always produces INT32 indices unless deliberately mis-typed.
    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        } - {DType.INT32}
        out_dtype = rng.choice(list(candidates))
    return ser.addOutput(out_shape, out_dtype)
@staticmethod
def _get_conv_output_type(input_dtype):
if input_dtype in (DType.FP16, DType.BF16, DType.FP32):
return input_dtype
elif input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
return DType.FP16
elif input_dtype in (DType.INT8, DType.INT4):
return DType.INT32
elif input_dtype in (DType.INT16,):
return DType.INT48
assert True, f"Unsupported convolution data type {input_dtype}"
@staticmethod
def conv2dOp(
ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
# IFM: NHWC
# Filter: OHWI
# OFM: NHWC
h = (
ifm.shape[1]
- 1
+ padding[0]
+ padding[1]
- (filter.shape[1] - 1) * dilations[0]
) // strides[0] + 1
w = (
ifm.shape[2]
- 1
+ padding[2]
+ padding[3]
- (filter.shape[2] - 1) * dilations[1]
) // strides[1] + 1
if error_name == ErrorIf.ConvOutputShapeMismatch:
choices = [1, 2, 3]
change = rng.choice(choices)
# increment in multiples of stride to not hit non-integer error case
if change in [1, 3]:
h = h + (rng.choice(choices) * strides[0])
if change in [2, 3]:
w = w + (rng.choice(choices) * strides[1])
ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
if error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
excludes = [DType.FP16, DType.FP32]
elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
excludes = [DType.FP16]
else:
excludes = [out_dtype]
wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
    @staticmethod
    def conv3dOp(
        ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
    ):
        """Register the CONV3D output tensor.

        Output spatial dims follow the standard convolution extent formula:
        (in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1.
        ERROR_IF cases deliberately perturb the shape or dtype.
        """
        # IFM: NDHWC
        # Filter: ODHWI
        # OFM: NDHWC
        d = (
            ifm.shape[1]
            - 1
            + padding[0]
            + padding[1]
            - (filter.shape[1] - 1) * dilations[0]
        ) // strides[0] + 1
        h = (
            ifm.shape[2]
            - 1
            + padding[2]
            + padding[3]
            - (filter.shape[2] - 1) * dilations[1]
        ) // strides[1] + 1
        w = (
            ifm.shape[3]
            - 1
            + padding[4]
            + padding[5]
            - (filter.shape[3] - 1) * dilations[2]
        ) // strides[2] + 1
        if error_name == ErrorIf.ConvOutputShapeMismatch:
            choices = [1, 2, 3, 4]
            change = rng.choice(choices)
            # increment in multiples of stride to not hit non-integer error case
            if change in [1, 4]:
                d = d + (rng.choice(choices) * strides[0])
            if change in [2, 4]:
                h = h + (rng.choice(choices) * strides[1])
            if change in [3, 4]:
                w = w + (rng.choice(choices) * strides[2])
        ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
        if error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            out_dtype = DType.INT32
        else:
            out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
        if error_name == ErrorIf.WrongOutputType:
            # NOTE(review): unlike conv2dOp there is no FP8-specific excludes
            # branch here — confirm whether that asymmetry is intentional.
            if ifm.dtype == DType.FP16:
                excludes = [DType.FP16, DType.FP32]
            else:
                excludes = [out_dtype]
            wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
            out_dtype = rng.choice(wrong_dtypes)
        return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
def depthwiseConv2dOp(
ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
# IFM: NHWC
# Filter: HWCM
# OFM: NHW C*M
h = (
ifm.shape[1]
- 1
+ padding[0]
+ padding[1]
- (filter.shape[0] - 1) * dilations[0]
) // strides[0] + 1
w = (
ifm.shape[2]
- 1
+ padding[2]
+ padding[3]
- (filter.shape[1] - 1) * dilations[1]
) // strides[1] + 1
if error_name == ErrorIf.ConvOutputShapeMismatch:
choices = [1, 2, 3]
change = rng.choice(choices)
# increment in multiples of stride to not hit non-integer error case
if change in [1, 3]:
h = h + (rng.choice(choices) * strides[0])
if change in [2, 3]:
w = w + (rng.choice(choices) * strides[1])
ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
if error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
# input: NHWC
if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
# If an incorrect stride is used set dimensions to 1, test is invalid anyway.
h = 1
w = 1
else:
h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
if error_name == ErrorIf.PoolingOutputShapeMismatch:
choices = [1, 2, 3]
change = rng.choice(choices)
# increment in multiples of stride to not hit non-integer error case
if change in [1, 3]:
h = h + (rng.choice(choices) * stride[0])
if change in [2, 3]:
w = w + (rng.choice(choices) * stride[1])
ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
DType.FP8E4M3,
DType.FP8E5M2,
]
wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = ifm.dtype
return ser.addOutput(ofm_shape, outputDType)
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
# input: N, IC
# filter: OC, IC
# output: N, OC
output_shape = [input.shape[0], filter.shape[0]]
# Validated in arg_gen (also invalidated for ErrorIf)
out_dtype = accum_dtype
return ser.addOutput(output_shape, out_dtype)
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
# a: N, H, C
# b: N, C, W
# out: N, H, W
output_shape = [a.shape[0], a.shape[1], b.shape[2]]
if error_name == ErrorIf.WrongOutputType:
if a.dtype == DType.INT8:
incorrect_types = (
DType.INT4,
DType.INT8,
DType.INT16,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
DType.FP8E4M3,
DType.FP8E5M2,
)
elif a.dtype == DType.INT16:
incorrect_types = (
DType.INT4,
DType.INT8,
DType.INT16,
DType.INT32,
DType.FP32,
DType.FP16,
DType.BF16,
DType.FP8E4M3,
DType.FP8E5M2,
)
elif a.dtype == DType.FP8E4M3 or a.dtype == DType.FP8E5M2:
incorrect_types = (
DType.INT4,
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.BF16,
DType.FP8E4M3,
DType.FP8E5M2,
)
elif (
a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
):
incorrect_types = (
DType.INT4,
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP8E4M3,
DType.FP8E5M2,
)
out_dtype = rng.choice(a=incorrect_types)
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
out_dtype = accum_dtype # Validated in arg_gen
return ser.addOutput(output_shape, out_dtype)
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
input1 = inputs[0]
remaining_inputs = inputs[1:]
# calculate the output shape, if possible, otherwise just use the first input shape
output_shape = input1.shape.copy()
if not (
# unable to concat tensors of different ranks
error_name == ErrorIf.ConcatInputRankMismatch
# unable to concat tensors along an invalid axis
or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
):
for tensor in remaining_inputs:
output_shape[axis] += tensor.shape[axis]
if error_name == ErrorIf.ConcatShapeSumMismatch:
output_shape[axis] += rng.integers(5, 10)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = {
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
}
wrong_dtypes = list(all_dtypes - set([input1.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = input1.dtype
return ser.addOutput(output_shape, outputDType)
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
output_shape = a.shape.copy()
for i in range(len(output_shape)):
output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
if error_name == ErrorIf.PadOutputShapeMismatch:
bad_dim = rng.choice(range(len(output_shape)))
output_shape[bad_dim] += rng.choice([1, 2])
elif error_name == ErrorIf.RankMismatch:
output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = a.dtype
return ser.addOutput(output_shape, outputDType)
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
output_shape = [1]
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = DType.SHAPE
return ser.addOutput(output_shape, outputDType)
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
output_shape = shape.copy()
if error_name == ErrorIf.TensorSizeInputOutputMismatch:
for i in range(len(output_shape)):
output_shape[i] = output_shape[i] + rng.integers(1, 10)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = a.dtype
return ser.addOutput(output_shape, outputDType)
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([input.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = input.dtype
output_shape = size.copy()
if error_name == ErrorIf.SizeOutputShapeMismatch:
for index in range(len(output_shape)):
if output_shape[index] <= 2:
output_shape[index] = output_shape[index] + rng.choice([1, 2])
else:
output_shape[index] = output_shape[index] + rng.choice(
[-2, -1, 1, 2]
)
elif error_name == ErrorIf.InputSizeStartLengthMismatch:
output_shape = input.shape.copy()
elif error_name == ErrorIf.RankMismatch:
output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
return ser.addOutput(output_shape, outputDType)
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
output_shape = a.shape.copy()
assert len(multiples) == len(output_shape)
for i in range(len(output_shape)):
output_shape[i] = a.shape[i] * multiples[i]
if error_name == ErrorIf.RankMismatch:
output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = a.dtype
return ser.addOutput(output_shape, outputDType)
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
output_shape = a.shape.copy()
assert len(perms) == len(output_shape)
if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
for i in range(len(output_shape)):
output_shape[i] = a.shape[perms[i]]
if error_name == ErrorIf.TensorSizeInputOutputMismatch:
for i in range(len(output_shape)):
output_shape[i] += rng.integers(1, 10)
elif error_name == ErrorIf.RankMismatch:
output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = a.dtype
return ser.addOutput(output_shape, outputDType)
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
if error_name != ErrorIf.WrongRank:
assert len(values.shape) == 3
assert len(indices.shape) == 2
assert values.shape[0] == indices.shape[0]
output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = values.dtype
return ser.addOutput(output_shape, outputDType)
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
if error_name != ErrorIf.WrongRank:
assert len(values_in.shape) == 3
assert len(indices.shape) == 2
assert len(input.shape) == 3
assert values_in.shape[0] == indices.shape[0] # N
assert input.shape[1] == indices.shape[1] # W
assert values_in.shape[2] == input.shape[2] # C
output_shape = values_in.shape
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
DType.FP8E4M3,
DType.FP8E5M2,
]
wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = values_in.dtype
return ser.addOutput(output_shape, outputDType)
@staticmethod
def tableOp(ser, rng, input, error_name=None):
# Same shape as the input, dtype dependent on input dtype
if error_name != ErrorIf.WrongInputType:
assert input.dtype == DType.INT16 or input.dtype == DType.INT8
output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
if error_name == ErrorIf.WrongOutputType:
wrong_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes.remove(output_dtype)
output_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(input.shape, output_dtype)
    @staticmethod
    def resizeOp(
        serializer,
        rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name=None,
    ):
        """Register the RESIZE output tensor.

        scale is (scale_y_n, scale_y_d, scale_x_n, scale_x_d); OH/OW are
        derived from the input H/W using the scale fractions, offset and
        border.  For ERROR_IF cases the computed dims are clamped/perturbed
        so only the intended error is triggered.
        """
        # Calculate OH, OW
        scale_y_n = scale[0]
        scale_y_d = scale[1]
        scale_x_n = scale[2]
        scale_x_d = scale[3]
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            # Avoid zero/negative divisors while testing the scale error itself
            scale_y_n = max(scale_y_n, 1)
            scale_y_d = max(scale_y_d, 1)
            scale_x_n = max(scale_x_n, 1)
            scale_x_d = max(scale_x_d, 1)
        oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
        ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1
        if error_name is not None:
            # Make sure the output tensor is valid, which can occur when
            # scale, offset or border have been changed for ERROR_IFs
            oh = max(oh, 1)
            ow = max(ow, 1)
            if error_name != ErrorIf.MaxDimExceeded:
                # Keep below the resize dimension limit unless that limit is
                # exactly the error under test
                oh = min(oh, gtu.MAX_RESIZE_DIMENSION - 1)
                ow = min(ow, gtu.MAX_RESIZE_DIMENSION - 1)
        if error_name == ErrorIf.ResizeOutputShapeMismatch:
            choices = [1, 2, 3]
            change = rng.choice(choices)
            # increment in multiples of scale_y/x_d so we don't hit non-integer error case
            if change in [1, 3]:
                if oh + scale_y_d >= gtu.MAX_RESIZE_DIMENSION:
                    oh -= scale_y_d
                    assert oh > 0  # Should have been caught in agResize
                else:
                    oh += scale_y_d
            if change in [2, 3]:
                if ow + scale_x_d >= gtu.MAX_RESIZE_DIMENSION:
                    ow -= scale_x_d
                    assert ow > 0  # Should have been caught in agResize
                else:
                    ow += scale_x_d
        if error_name == ErrorIf.WrongRank:
            # Rank is right (4) but channels dim is wrong, mimicking rank error
            output_dims = [
                input.shape[0],
                oh,
                ow,
                input.shape[0],
            ]
        elif error_name == ErrorIf.BatchMismatch:
            output_dims = [
                input.shape[0] + rng.integers(1, 10),
                oh,
                ow,
                input.shape[3],
            ]
        elif error_name == ErrorIf.ChannelMismatch:
            output_dims = [
                input.shape[0],
                oh,
                ow,
                input.shape[3] + rng.integers(1, 10),
            ]
        else:
            output_dims = [input.shape[0], oh, ow, input.shape[3]]
        return serializer.addOutput(output_dims, output_dtype)
    @staticmethod
    def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
        """Register a CAST/RESCALE output: input shape, caller-chosen dtype."""
        return ser.addOutput(val.shape, out_dtype)
@staticmethod
def transposeConv2DOp(
ser, rng, ifm, filter, accum_dtype, strides, padding, error_name=None
):
h = (ifm.shape[1] - 1) * strides[0] + padding[0] + padding[1] + filter.shape[1]
w = (ifm.shape[2] - 1) * strides[1] + padding[2] + padding[3] + filter.shape[2]
if error_name == ErrorIf.ConvOutputShapeMismatch:
choices = [1, 2, 3]
change = rng.choice(choices)
if change in [1, 3]:
h = h + rng.choice(choices)
if change in [2, 3]:
w = w + rng.choice(choices)
ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
if error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
outputs = []
assert ifm1.dtype == ifm2.dtype
input_dtype = ifm1.dtype
if error_name != ErrorIf.FFTInputShapeMismatch:
assert ifm1.shape == ifm2.shape
input_shape = ifm1.shape
if error_name != ErrorIf.WrongRank:
assert len(input_shape) == 3
output_shape = input_shape.copy()
output_dtype = input_dtype
if error_name == ErrorIf.WrongOutputType:
excludes = [DType.FP32]
wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
output_dtype = rng.choice(wrong_dtypes)
elif error_name == ErrorIf.BatchMismatch:
output_shape[0] += rng.integers(1, 10)
elif error_name == ErrorIf.FFTOutputShapeMismatch:
modify_dim = rng.choice([1, 2])
output_shape[modify_dim] += rng.integers(1, 10)
outputs.append(serializer.addOutput(output_shape, output_dtype))
outputs.append(serializer.addOutput(output_shape, output_dtype))
return outputs
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
outputs = []
input_shape = value.shape
if error_name != ErrorIf.WrongRank:
assert len(input_shape) == 3
output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]
output_dtype = value.dtype
if error_name == ErrorIf.WrongOutputType:
excludes = [DType.FP32]
wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
output_dtype = rng.choice(wrong_dtypes)
elif error_name == ErrorIf.BatchMismatch:
output_shape[0] += rng.integers(1, 10)
elif error_name == ErrorIf.FFTOutputShapeMismatch:
modify_dim = rng.choice([1, 2])
output_shape[modify_dim] += rng.integers(1, 10)
outputs.append(serializer.addOutput(output_shape, output_dtype))
outputs.append(serializer.addOutput(output_shape, output_dtype))
return outputs
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
if error_name != ErrorIf.RankMismatch:
assert len(a.shape) == len(b.shape)
assert a.dtype == b.dtype
shape = a.shape.copy()
# Do not expect rank 0 tests!
assert len(shape) > 0
if error_name == ErrorIf.DimensionMismatch:
# Can only create this error for rank > 0
fuzz_idx = rng.integers(0, len(shape))
shape[fuzz_idx] += 1
if error_name == ErrorIf.WrongOutputType:
wrong_dtypes = gtu.get_wrong_output_type(a.dtype)
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = DType.SHAPE
return ser.addOutput(shape, outputDType)