# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import os
from copy import deepcopy
import numpy as np
import serializer.tosa_serializer as ts
from generator.tosa_arg_gen import TosaArgGen
from generator.tosa_arg_gen import TosaQuantGen
from generator.tosa_arg_gen import TosaTensorGen
from generator.tosa_arg_gen import TosaTensorValuesGen
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_error_if import TosaErrorValidator
from generator.tosa_error_if import TosaInvalidValidator
from generator.tosa_utils import DTYPE_ATTRIBUTES
from generator.tosa_utils import get_rank_mismatch_shape
from generator.tosa_utils import get_wrong_output_type
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import usableDTypes
from generator.tosa_utils import vect_f32_to_bf16
from tosa.DType import DType
from tosa.Op import Op
class TosaTestGen:
# Maximum rank of tensor supported by test generator.
# This currently matches the 8K level defined in the specification.
TOSA_TENSOR_MAX_RANK = 6
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8192
TOSA_8K_LEVEL_MAX_STRIDE = 8192
def __init__(self, args):
self.args = args
self.basePath = args.output_dir
self.random_seed = args.random_seed
self.ser = None
self.rng = np.random.default_rng(self.random_seed)
self.createDynamicOpLists()
self.initOpListDefaults()
self.quantGen = TosaQuantGen()
        # When set, forces makeShape to return this specific shape
        self.targetted_shape = None
# Work out floating point range
self.random_fp_low = min(args.tensor_fp_value_range)
self.random_fp_high = max(args.tensor_fp_value_range)
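    # Typical lifecycle (a sketch; `args` is the parsed command-line namespace
    # and the test path shown is hypothetical):
    #   gen = TosaTestGen(args)
    #   gen.createSerializer("add", "add_1x2x3_i32")
    #   ... build the operator under test ...
    #   gen.serialize("add_1x2x3_i32")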
def createSerializer(self, opName, testPath):
self.testPath = os.path.join(opName, testPath)
fullPath = os.path.join(self.basePath, self.testPath)
os.makedirs(fullPath, exist_ok=True)
# Embed const data in the flatbuffer
constMode = ts.ConstMode.EMBED
if self.args.dump_consts:
constMode = ts.ConstMode.EMBED_DUMP
self.ser = ts.TosaSerializer(fullPath, constMode)
def getSerializer(self):
return self.ser
def serialize(self, testName):
with open(
os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
) as fd:
fd.write(self.ser.serialize())
with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
fd.write(self.ser.writeJson("{}.tosa".format(testName)))
def resetRNG(self, seed=None):
if seed is None:
seed = self.random_seed + 1
self.rng = np.random.default_rng(seed)
def getRandTensor(self, shape, dtype):
if dtype == DType.BOOL:
return np.bool_(self.rng.choice(a=[False, True], size=shape))
# TOSA specific INT4 weight range from -7 to 7
elif dtype == DType.INT4:
return np.int32(self.rng.integers(low=-7, high=8, size=shape))
elif dtype == DType.INT8:
return np.int32(self.rng.integers(low=-128, high=128, size=shape))
elif dtype == DType.UINT8:
return np.int32(self.rng.integers(low=0, high=256, size=shape))
elif dtype == DType.INT16:
return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
elif dtype == DType.UINT16:
return np.int32(self.rng.integers(low=0, high=65536, size=shape))
        elif dtype == DType.INT32 or dtype == DType.SHAPE:
            # SHAPE values share the INT32 range to avoid overly large values
return np.int32(
self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
)
elif dtype == DType.INT48:
return np.int64(
self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
)
elif dtype == DType.FP16:
return np.float16(
self.rng.uniform(
low=self.random_fp_low, high=self.random_fp_high, size=shape
)
)
elif dtype == DType.BF16:
f32_tensor = np.float32(
self.rng.uniform(
low=self.random_fp_low, high=self.random_fp_high, size=shape
)
)
            # Zero the low 16 bits of each f32 value to produce valid bf16 values
return np.float32(vect_f32_to_bf16(f32_tensor))
elif dtype == DType.FP32:
return np.float32(
self.rng.uniform(
low=self.random_fp_low, high=self.random_fp_high, size=shape
)
)
else:
raise Exception("Unrecognized Dtype: {}".format(dtype))
def buildPlaceholderTensors(self, shape_list, dtype_list):
placeholders = []
assert len(shape_list) == len(dtype_list)
for idx, shape in enumerate(shape_list):
arr = self.getRandTensor(shape, dtype_list[idx])
placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
return placeholders
def buildConstTensors(self, shape_list, dtype_list):
consts = []
assert len(shape_list) == len(dtype_list)
for idx, shape in enumerate(shape_list):
arr = self.getRandTensor(shape, dtype_list[idx])
consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
return consts
def makeShape(self, rank):
if self.targetted_shape:
return np.int32(self.targetted_shape)
return np.int32(
self.rng.integers(
low=self.args.tensor_shape_range[0],
high=self.args.tensor_shape_range[1],
size=rank,
)
)
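    # e.g. makeShape(3) might return np.int32([5, 2, 7]) for a
    # tensor_shape_range of (1, 16)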
def setTargetShape(self, shape):
self.targetted_shape = shape
def randInt(self, low=0, high=256):
return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
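    # Note: the upper bound is exclusive, e.g. randInt(0, 4) draws from {0, 1, 2, 3}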
def getRandNumberDType(self, dtype):
if dtype == DType.FP32:
return np.float32(
self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
)
elif dtype == DType.FP16:
return np.float16(
self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
)
elif dtype == DType.BF16:
rand_f32 = np.float32(
self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
)
return vect_f32_to_bf16(rand_f32)
elif dtype == DType.BOOL:
return self.rng.choice([False, True])
# TOSA specific INT4 weight range from -7 to 7
elif dtype == DType.INT4:
low, high = (-7, 8)
elif dtype == DType.INT8:
low, high = (-128, 128)
elif dtype == DType.INT16:
low, high = (-32768, 32768)
        elif dtype == DType.INT32 or dtype == DType.SHAPE:
            # SHAPE values share the INT32 range to avoid overly large values
low, high = (-(1 << 31), (1 << 31))
elif dtype == DType.INT48:
low, high = (-(1 << 47), (1 << 47))
            # INT48 values need a 64-bit container
return np.int64(self.rng.integers(low, high, size=1))[0]
else:
raise Exception("Unknown dtype: {}".format(dtype))
return np.int32(self.rng.integers(low, high, size=1))[0]
def shapeStr(self, shape):
sStr = []
# Convert to strings
for i in shape:
sStr.append(str(i))
return "x".join(sStr)
def typeStr(self, dtype):
if isinstance(dtype, list) or isinstance(dtype, tuple):
assert len(dtype) >= 2
strs = [self.typeStr(t) for t in dtype]
# Limit types to the first 2 as the 3rd is the accumulator
return "x".join(strs[:2])
else:
if dtype in DTYPE_ATTRIBUTES:
return DTYPE_ATTRIBUTES[dtype]["str"]
else:
raise Exception(
"Unknown dtype, cannot convert to string: {}".format(dtype)
)
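    # Example sketch, assuming the usual short names in DTYPE_ATTRIBUTES
    # (e.g. "i8", "i32"):
    #   typeStr([DType.INT8, DType.INT32, DType.INT48]) -> "i8xi32"
    # (the third entry, the accumulator type, is dropped from the name)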
    def typeWidth(self, dtype):
        """Return the width in bits of the given data type."""
if dtype in DTYPE_ATTRIBUTES:
return DTYPE_ATTRIBUTES[dtype]["width"]
else:
raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
def constrictBatchSize(self, shape):
# Limit the batch size unless an explicit target shape set
if self.args.max_batch_size and not self.args.target_shapes:
shape[0] = min(shape[0], self.args.max_batch_size)
return shape
def makeDimension(self):
return self.randInt(
low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
)
# Argument generators
    # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
    # where the string descriptor is used to generate the test name and
    # the build_fcn_arg_list is expanded and passed to the operator test
    # build function
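    # Example (hypothetical descriptors):
    #   [("axis0", [0]), ("axis1", [1])]
    # would generate two tests whose names end in "axis0"/"axis1", with each
    # argument list expanded into the build function call.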
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
        # build_placeholder returns an int for op; ABS and other ops pass an op dict
if isinstance(op, int):
self.ser.addOperator(op, a.name, result_tens.name, None)
return result_tens
elif op["op"] == Op.IDENTITY:
self.ser.addOperator(op["op"], a.name, result_tens.name, None)
return result_tens
# Ensure new output type has correct qinfo
if error_name == ErrorIf.WrongOutputType:
if result_tens.dtype not in [DType.INT8, DType.UINT8]:
qinfo = [
TosaQuantGen.getZeroPoint(self, a.dtype),
TosaQuantGen.getZeroPoint(self, result_tens.dtype),
]
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
qinfo=qinfo,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
attr = None
if op["op"] == Op.NEGATE:
attr = ts.TosaSerializerAttribute()
attr.NegateAttribute(qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
result_tens = OutputShaper.binaryBroadcastOp(
self.ser, self.rng, a, b, error_name
)
# Invalidate Input/Output list for error if checks.
input_list = [a.name, b.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input1=a,
input2=b,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
self.ser.addOperator(op["op"], input_list, output_list)
return result_tens
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
return result_tens
def build_arithmetic_right_shift(
self, op, a, b, round, validator_fcns=None, error_name=None
):
result_tens = OutputShaper.binaryBroadcastOp(
self.ser, self.rng, a, b, error_name
)
# Invalidate Input/Output list for error if checks.
input_list = [a.name, b.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input1=a,
input2=b,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
attr = ts.TosaSerializerAttribute()
attr.ArithmeticRightShiftAttribute(round)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
result_tens = OutputShaper.binaryBroadcastOp(
self.ser, self.rng, a, b, error_name
)
# Special for multiply:
# Force the result to INT32 for INT types
if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
result_tens.setDtype(DType.INT32)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
outputDType = self.rng.choice(all_dtypes)
result_tens.setDtype(outputDType)
# Invalidate Input/Output list for error if checks.
input_list = [a.name, b.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input1=a,
input2=b,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
attr = ts.TosaSerializerAttribute()
attr.MulAttribute(shift)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
attr = ts.TosaSerializerAttribute()
attr.TableAttribute(table)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=a.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
# Invalidate Input/Output list for error if checks.
input_list = [cond.name, a.name, b.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input1=cond,
input2=a,
input3=b,
input_shape=a.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
self.ser.addOperator(
op["op"],
input_list,
output_list,
)
return result_tens
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
result_tens = OutputShaper.binaryComparisonOp(
self.ser, self.rng, a, b, error_name
)
# Invalidate Input/Output list for error if checks.
input_list = [a.name, b.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input1=a,
input2=b,
input_shape=a.shape,
input_dtype=a.dtype,
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
self.ser.addOperator(
op["op"],
input_list,
output_list,
)
return result_tens
def build_argmax(self, op, a, axis, validator_fcns, error_name):
result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
axis=axis,
input_shape=a.shape,
input_dtype=a.dtype,
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_pool2d(
self,
op,
input,
accum_dtype,
stride,
pad,
kernel,
validator_fcns=None,
error_name=None,
qinfo=None,
):
result_tens = OutputShaper.pool2dOp(
self.ser, self.rng, input, kernel, stride, pad, error_name
)
# Ensure new output type has correct qinfo
if error_name == ErrorIf.WrongInputType:
if input.dtype not in [DType.INT8, DType.UINT8]:
qinfo = [
TosaQuantGen.getZeroPoint(self, input.dtype),
TosaQuantGen.getZeroPoint(self, result_tens.dtype),
]
# Invalidate Input/Output list for error if checks.
input_list = [input.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=input.shape,
input_dtype=input.dtype,
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
kernel=kernel,
stride=stride,
pad=pad,
qinfo=qinfo,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
if qinfo is None:
qinfo = [0, 0]
attr = ts.TosaSerializerAttribute()
attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_maxpool2d(
self,
op,
input,
stride,
pad,
kernel,
validator_fcns=None,
error_name=None,
qinfo=None,
):
# Same as build_pool2d but manually sets accum_dtype value
# (maxpool has no accum_dtype)
return self.build_pool2d(
op,
input,
DType.UNKNOWN,
stride,
pad,
kernel,
validator_fcns,
error_name,
qinfo,
)
def build_conv2d(
self,
op,
ifm,
filter,
bias,
accum_dtype,
strides,
padding,
dilations,
validator_fcns=None,
error_name=None,
qinfo=None,
):
assert len(padding) == 4
result_tens = OutputShaper.conv2dOp(
self.ser,
self.rng,
ifm,
filter,
accum_dtype,
strides,
padding,
dilations,
error_name,
)
# Ensure new output type has correct qinfo
if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
DType.INT8,
DType.UINT8,
):
qinfo = [
TosaQuantGen.getZeroPoint(self, ifm.dtype),
TosaQuantGen.getZeroPoint(self, result_tens.dtype),
]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
output_list = [result_tens.name]
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_dtype=ifm.dtype,
weight_dtype=filter.dtype,
output_dtype=result_tens.dtype,
qinfo=qinfo,
input_list=input_list,
num_operands=num_operands,
output_list=output_list,
pad=padding,
stride=strides,
dilation=dilations,
input_shape=ifm.shape,
weight_shape=filter.shape,
output_shape=result_tens.shape,
):
return None
attr = ts.TosaSerializerAttribute()
attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_conv3d(
self,
op,
ifm,
filter,
bias,
accum_dtype,
strides,
padding,
dilations,
validator_fcns=None,
error_name=None,
qinfo=None,
):
assert len(padding) == 6
result_tens = OutputShaper.conv3dOp(
self.ser,
self.rng,
ifm,
filter,
accum_dtype,
strides,
padding,
dilations,
error_name,
)
# Ensure new output type has correct qinfo
if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
DType.INT8,
DType.UINT8,
):
qinfo = [
TosaQuantGen.getZeroPoint(self, ifm.dtype),
TosaQuantGen.getZeroPoint(self, result_tens.dtype),
]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
output_list = [result_tens.name]
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_dtype=ifm.dtype,
weight_dtype=filter.dtype,
output_dtype=result_tens.dtype,
qinfo=qinfo,
input_list=input_list,
num_operands=num_operands,
output_list=output_list,
pad=padding,
stride=strides,
dilation=dilations,
input_shape=ifm.shape,
weight_shape=filter.shape,
output_shape=result_tens.shape,
):
return None
attr = ts.TosaSerializerAttribute()
attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_transpose_conv2d(
self,
op,
ifm,
filter,
bias,
accum_dtype,
stride,
out_pad,
output_shape,
validator_fcns=None,
error_name=None,
qinfo=None,
):
assert len(out_pad) == 4
result_tens = OutputShaper.transposeConv2DOp(
self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
)
# Ensure new output type has correct qinfo
if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
DType.INT8,
DType.UINT8,
):
qinfo = [
TosaQuantGen.getZeroPoint(self, ifm.dtype),
TosaQuantGen.getZeroPoint(self, result_tens.dtype),
]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
output_list = [result_tens.name]
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_dtype=ifm.dtype,
weight_dtype=filter.dtype,
output_dtype=result_tens.dtype,
qinfo=qinfo,
input_list=input_list,
num_operands=num_operands,
output_list=output_list,
pad=out_pad,
stride=stride,
input_shape=ifm.shape,
weight_shape=filter.shape,
output_shape=result_tens.shape,
):
return None
attr = ts.TosaSerializerAttribute()
attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_depthwise_conv2d(
self,
op,
ifm,
filter,
bias,
accum_dtype,
strides,
padding,
dilations,
validator_fcns=None,
error_name=None,
qinfo=None,
):
result_tens = OutputShaper.depthwiseConv2dOp(
self.ser,
self.rng,
ifm,
filter,
accum_dtype,
strides,
padding,
dilations,
error_name,
)
# Ensure new output type has correct qinfo
if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
DType.INT8,
DType.UINT8,
):
qinfo = [
TosaQuantGen.getZeroPoint(self, ifm.dtype),
TosaQuantGen.getZeroPoint(self, result_tens.dtype),
]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
output_list = [result_tens.name]
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_dtype=ifm.dtype,
weight_dtype=filter.dtype,
output_dtype=result_tens.dtype,
qinfo=qinfo,
input_list=input_list,
num_operands=num_operands,
output_list=output_list,
pad=padding,
stride=strides,
dilation=dilations,
input_shape=ifm.shape,
weight_shape=filter.shape,
output_shape=result_tens.shape,
):
return None
attr = ts.TosaSerializerAttribute()
attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_fully_connected(
self,
op,
ifm,
filter,
bias,
accum_dtype,
validator_fcns=None,
error_name=None,
qinfo=None,
):
result_tens = OutputShaper.fullyConnectedOp(
self.ser, self.rng, ifm, filter, accum_dtype, error_name
)
# Invalidate Input/Output list for error if checks.
input_list = [ifm.name, filter.name, bias.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=ifm.shape,
input_dtype=ifm.dtype,
weight_dtype=filter.dtype,
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
qinfo=qinfo,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
accum_dtype=accum_dtype,
):
return None
attr = ts.TosaSerializerAttribute()
attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_matmul(
self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
):
result_tens = OutputShaper.matmulOp(
self.ser, self.rng, a, b, accum_dtype, error_name
)
# Invalidate Input/Output list for error if checks.
input_list = [a.name, b.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=a.shape,
input_dtype=a.dtype,
input2_shape=b.shape,
input2_dtype=b.dtype,
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
qinfo=qinfo,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
accum_dtype=accum_dtype,
):
return None
attr = ts.TosaSerializerAttribute()
attr.MatMulAttribute(qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
axis=axis,
input_shape=a.shape,
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
if error_name == ErrorIf.MaxSmallerMin:
# Make sure the numbers are different to invoke this error
while v[0] == v[1]:
v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
max_val = min(v)
min_val = max(v)
else:
max_val = max(v)
min_val = min(v)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
max_val=max_val,
min_val=min_val,
input_shape=a.shape,
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
attr = ts.TosaSerializerAttribute()
if a.dtype in (DType.BF16, DType.FP16, DType.FP32):
if a.dtype == DType.FP16:
# Non-tensor fp16 ops take fp16 values as fp32 in reference_model
min_val = min_val.astype(np.float32)
max_val = max_val.astype(np.float32)
attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
else:
attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
attr = ts.TosaSerializerAttribute()
attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))
self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr)
return result_tens
# Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
self.ser.addOperator(op["op"], [a.name], [result_tens.name])
return result_tens
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=a.shape,
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
self.ser.addOperator(op["op"], input_list, output_list)
return result_tens
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=a.shape,
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
self.ser.addOperator(op["op"], input_list, output_list)
return result_tens
def build_erf(self, op, a, validator_fcns=None, error_name=None):
result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=a.shape,
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
self.ser.addOperator(op["op"], input_list, output_list)
return result_tens
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
        if error_name != ErrorIf.WrongInputType:
            assert isinstance(a[-1], int)
        # The axis is passed as the final element of the variable-length input list
        axis = a[-1]
        a = a[:-1]
result_tens = OutputShaper.concatOp(
self.ser, self.rng, axis, *a, error_name=error_name
)
input_tensor_names = []
for tensor in a:
input_tensor_names.append(tensor.name)
# Invalidate Input/Output list for error if checks.
input_list = input_tensor_names
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
axis=axis,
input_shape=a[0].shape,
output_shape=result_tens.shape,
input_dtype=a[0].dtype,
output_dtype=result_tens.dtype,
inputs=a,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_pad(
self,
op,
a,
padding,
pad_const_int,
pad_const_float,
validator_fcns=None,
error_name=None,
qinfo=None,
):
result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
attr = ts.TosaSerializerAttribute()
attr.PadAttribute(
self.ser.builder, padding.flatten(), pad_const_int, pad_const_float
)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=a.shape,
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
pad=padding,
qinfo=qinfo,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
input1=a,
):
return None
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_dim(
self,
op,
a,
axis,
validator_fcns=None,
error_name=None,
qinfo=None,
):
result_tens = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
axis=axis,
input_shape=a.shape,
input_dtype=a.dtype,
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
result_tens = OutputShaper.reshapeOp(
self.ser, self.rng, a, newShape, error_name
)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=a.shape,
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
attr = ts.TosaSerializerAttribute()
attr.ReshapeAttribute(newShape)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
axis=axis,
input_shape=a.shape,
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)
attr = ts.TosaSerializerAttribute()
attr.TransposeAttribute(perms)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=a.shape,
output_shape=result_tens.shape,
perms=perms,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
input1=a,
):
return None
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
result_tens = OutputShaper.sliceOp(
self.ser, self.rng, a, start, size, error_name
)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=a.shape,
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
start=start,
size=size,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
input1=a,
):
return None
attr = ts.TosaSerializerAttribute()
attr.SliceAttribute(start, size)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=a.shape,
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
input1=a,
):
return None
attr = ts.TosaSerializerAttribute()
attr.TileAttribute(multiples)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
    def build_gather(self, op, values, validator_fcns=None, error_name=None):
        # Create a new indices tensor
        # here with data that doesn't exceed the dimensions of the values tensor
        K = values.shape[1]  # K
        W = self.randInt(
            self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
        )  # W
        indices_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values.shape[0], W])
        )  # (N, W)
        indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)
        result_tens = OutputShaper.gatherOp(
            self.ser, self.rng, values, indices, error_name
        )
# Invalidate Input/Output list for error if checks.
        input_list = [values.name, indices.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=values.shape,
output_shape=result_tens.shape,
input_dtype=values.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
self.ser.addOperator(op["op"], input_list, output_list)
return result_tens
    def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
        # Create a new indices tensor
        # here with data that doesn't exceed the dimensions of the values_in tensor
        K = values_in.shape[1]  # K
        W = input.shape[1]  # W
        indices_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
        )  # (N, W)
        indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)
        result_tens = OutputShaper.scatterOp(
            self.ser, self.rng, values_in, indices, input, error_name
        )
# Invalidate Input/Output list for error if checks.
        input_list = [values_in.name, indices.name, input.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=values_in.shape,
output_shape=result_tens.shape,
input_dtype=values_in.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
self.ser.addOperator(op["op"], input_list, output_list)
return result_tens
def build_resize(
self,
op,
input,
mode,
scale,
offset,
border,
input_dtype,
output_dtype,
validator_fcns,
error_name=None,
):
result_tens = OutputShaper.resizeOp(
self.ser,
self.rng,
input,
mode,
scale,
offset,
border,
input_dtype,
output_dtype,
error_name,
)
# Invalidate Input/Output list for error if checks.
input_list = [input.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
mode=mode,
scale=scale,
input_dtype=input_dtype,
output_dtype=output_dtype,
input_shape=input.shape,
output_shape=result_tens.shape,
offset=offset,
border=border,
input_list=input_list,
output_list=output_list,
result_tensors=[result_tens],
num_operands=num_operands,
):
return None
attr = ts.TosaSerializerAttribute()
attr.ResizeAttribute(scale, offset, border, mode)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
self.ser.addOperator(
op, [val.name, val2.name], [result_tens.name, result_tens2.name]
)
return result_tens
def build_const(self, op, val, validator_fcns=None, error_name=None):
self.ser.addOutputTensor(val)
return val
# Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
result_tens = OutputShaper.typeConversionOp(
self.ser, self.rng, val, out_dtype, error_name
)
# Invalidate Input/Output list for error if checks.
input_list = [val.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=val.shape,
output_shape=result_tens.shape,
input_dtype=val.dtype,
output_dtype=result_tens.dtype,
result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
self.ser.addOperator(op["op"], input_list, output_list)
return result_tens
def build_rescale(
self,
op,
val,
out_dtype,
scale32,
double_round,
per_channel,
validator_fcns,
error_name,
):
result_tens = OutputShaper.typeConversionOp(
self.ser, self.rng, val, out_dtype, error_name
)
if per_channel:
nc = val.shape[-1]
else:
nc = 1
in_type_width = self.typeWidth(val.dtype)
out_type_width = self.typeWidth(out_dtype)
if val.dtype == DType.INT8:
input_zp = self.randInt(-128, 128)
in_type_width += 1
elif val.dtype == DType.UINT8:
input_zp = self.randInt(0, 256)
in_type_width += 1
elif error_name in [
ErrorIf.InputZeroPointNotZero,
ErrorIf.U16InputZeroPointNotValid,
]:
input_zp = self.randInt(-128, 128)
if input_zp == 0:
input_zp = input_zp + self.rng.integers(1, 10)
in_type_width += 1
elif val.dtype == DType.UINT16:
# Must come after ErrorIf.U16InputZeroPointNotValid check
input_zp = self.rng.choice([0, 32768])
in_type_width += 1
else:
input_zp = 0
if out_dtype == DType.INT8:
output_zp = self.randInt(-128, 128)
out_type_width += 1
elif out_dtype == DType.UINT8:
output_zp = self.randInt(0, 256)
out_type_width += 1
elif error_name in [
ErrorIf.OutputZeroPointNotZero,
ErrorIf.U16OutputZeroPointNotValid,
]:
output_zp = self.randInt(-128, 128)
if output_zp == 0:
output_zp = output_zp + self.rng.integers(1, 10)
out_type_width += 1
elif out_dtype == DType.UINT16:
# Must come after ErrorIf.U16OutputZeroPointNotValid check
output_zp = self.rng.choice([0, 32768])
out_type_width += 1
else:
output_zp = 0
        # Calculate scale based on:
        # scale = a * (2^output_width) / (2^input_width)
        a = np.float32(self.rng.random(size=[nc]))
        scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
        if scale32:
            # Cap the scaling at 2^31 - 1 for scale32
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
else:
# Cap the scaling at 2^15 - 1 for scale16
scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
multiplier_arr = np.int32(np.zeros(shape=[nc]))
shift_arr = np.int32(np.zeros(shape=[nc]))
min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
for i in range(nc):
multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
scale_arr[i], scale32
)
min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
if scale32 and error_name is None:
# Make sure random values are within apply_scale_32 specification
# REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
assert val.placeholderFilename
values = np.load(
os.path.join(self.basePath, self.testPath, val.placeholderFilename)
)
val_adj = np.subtract(values, input_zp, dtype=np.int64)
val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
# Values changed so overwrite file with new values
np.save(
os.path.join(self.basePath, self.testPath, val.placeholderFilename),
val_adj,
False,
)
# Invalidate Input/Output list for error if checks.
input_list = [val.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
)
qinfo = (input_zp, output_zp)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_dtype=val.dtype,
output_dtype=out_dtype,
input_shape=val.shape,
qinfo=qinfo,
scale32=scale32,
double_round=double_round,
input_list=input_list,
output_list=output_list,
result_tensors=[result_tens],
num_operands=num_operands,
):
return None
attr = ts.TosaSerializerAttribute()
attr.RescaleAttribute(
input_zp,
output_zp,
multiplier_arr,
shift_arr,
scale32,
double_round,
per_channel,
)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def _get_condition_tensor(self, op, cond, error_name):
if error_name == ErrorIf.CondIfCondNotMatchingBool:
cond_type = get_wrong_output_type(op, self.rng, DType.BOOL)
else:
cond_type = DType.BOOL
if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
choice = self.rng.choice([1, 2])
if choice == 1:
cond_shape = [2]
else:
cond_shape = [1, 2]
else:
# Must be of size 1 (rank 0)
cond_shape = []
cond_tens = self.ser.addConst(cond_shape, cond_type, [cond])
return cond_tens
def build_cond_if_const(
self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
# For cond_if with constants, we're supplied with then/else tensors that we ignore
# (except for the generated shape) and the condition. Build Then/Else blocks
# and fill them with const nodes for the body.
# Condition tensor
cond_tens = self._get_condition_tensor(op, cond, error_name)
# Make then/else tensors
out_shape = then_tens.shape
# Create an incorrect output shape for error_if tests
if error_name in [
ErrorIf.CondIfOutputListThenGraphMismatch,
ErrorIf.CondIfOutputListElseGraphMismatch,
]:
incorrect_shape = deepcopy(then_tens.shape)
for i in range(len(incorrect_shape)):
incorrect_shape[i] += (
self.rng.choice([-3, -2, 2, 3])
if incorrect_shape[i] > 3
else self.rng.choice([1, 2, 4])
)
incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))
then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
# And the result tensor based on any of the outputs
result_tens = self.ser.addOutput(out_shape, DType.INT32)
# Create the attribute with the names of the then/else blocks
then_block = "THEN_BLOCK"
else_block = "ELSE_BLOCK"
attr = ts.TosaSerializerAttribute()
attr.CondIfAttribute(then_block, else_block)
# Finally, build the op and the two blocks
self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)
self.ser.addBasicBlock(then_block)
# Build the actual then/else tensors inside their blocks
if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
else:
then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
self.ser.addOutputTensor(then_tens)
self.ser.addBasicBlock(else_block)
if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
else:
else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
self.ser.addOutputTensor(else_tens)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
basicBlocks=self.ser.currRegion.basicBlocks,
cond=cond_tens,
):
return None
return result_tens
def build_cond_if_binary(
self, op, a, b, cond, validator_fcns=None, error_name=None
):
# For cond_if with a binary op in the then/else blocks, take a and b and
# alternately add or subtract them based on the condition
# Condition tensor
cond_tens = self._get_condition_tensor(op, cond, error_name)
result_tens = self.ser.addOutput(a.shape, a.dtype)
# Create the attribute with the names of the then/else blocks
then_block = "THEN_BLOCK"
else_block = "ELSE_BLOCK"
attr = ts.TosaSerializerAttribute()
attr.CondIfAttribute(then_block, else_block)
if error_name in [
ErrorIf.CondIfInputListThenGraphMismatch,
ErrorIf.CondIfInputListElseGraphMismatch,
ErrorIf.CondIfOutputListElseGraphMismatch,
ErrorIf.CondIfOutputListThenGraphMismatch,
]:
incorrect_shape = a.shape.copy()
for i in range(len(incorrect_shape)):
incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
incorrect_block_input = deepcopy(a)
incorrect_block_input.shape = incorrect_shape
# Finally, build the op and the two blocks
self.ser.addOperator(
op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
)
if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
then_op, else_op = Op.ADD, Op.SUB
elif a.dtype in (DType.INT8, DType.INT16):
then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
else:
assert False, f"No tests for DType: {a.dtype}"
        # Use a distinct loop variable so the `op` dict is not shadowed
        for block, block_op in ((then_block, then_op), (else_block, else_op)):
self.ser.addBasicBlock(block)
if (
error_name == ErrorIf.CondIfInputListThenGraphMismatch
and block == then_block
) or (
error_name == ErrorIf.CondIfInputListElseGraphMismatch
and block == else_block
):
self.ser.addInputTensor(incorrect_block_input)
self.ser.addInputTensor(b)
tens = self.ser.addOutput(a.shape, a.dtype)
elif (
error_name == ErrorIf.CondIfOutputListThenGraphMismatch
and block == then_block
) or (
error_name == ErrorIf.CondIfOutputListElseGraphMismatch
and block == else_block
):
self.ser.addInputTensor(a)
self.ser.addInputTensor(b)
tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
else:
self.ser.addInputTensor(a)
self.ser.addInputTensor(b)
tens = self.ser.addOutput(a.shape, a.dtype)
            self.ser.addOperator(block_op, [a.name, b.name], [tens.name])
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
a=a,
b=b,
basicBlocks=self.ser.currRegion.basicBlocks,
cond=cond_tens,
):
return None
return result_tens
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
cond_block = "COND_BLOCK"
body_block = "BODY_BLOCK"
attr = ts.TosaSerializerAttribute()
attr.WhileLoopAttribute(cond_block, body_block)
# Accumulator tensor
acc_init_val = np.int32(np.zeros(a.shape))
acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
# Intermediate/output tensors for everything going through the loop
iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
a_out = self.ser.addIntermediate(a.shape, a.dtype)
if error_name == ErrorIf.InputListOutputListMismatch:
incorrect_acc = deepcopy(acc)
for i in range(len(incorrect_acc.shape)):
incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
else:
acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
# While_loop operator
self.ser.addOperator(
op["op"],
[iter.name, a.name, acc.name],
[iter_out.name, a_out.name, acc_out.name],
attr,
)
self.ser.addOutputTensor(acc_out)
if error_name in [
ErrorIf.InputListCondGraphMismatch,
ErrorIf.InputListBodyGraphInputMismatch,
ErrorIf.InputListBodyGraphOutputMismatch,
]:
incorrect_iter = deepcopy(iter)
for i in range(len(incorrect_iter.shape)):
incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
if len(incorrect_iter.shape) == 0:
incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
incorrect_acc = deepcopy(acc)
for i in range(len(incorrect_acc.shape)):
incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
# COND block (input: iter, output: cond_tens )
self.ser.addBasicBlock(cond_block)
if error_name == ErrorIf.InputListCondGraphMismatch:
self.ser.addInputTensor(incorrect_iter)
self.ser.addInputTensor(a)
self.ser.addInputTensor(incorrect_acc)
else:
self.ser.addInputTensor(iter)
self.ser.addInputTensor(a)
self.ser.addInputTensor(acc)
zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
else:
cond_type = DType.BOOL
if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
choice = self.rng.choice([1, 2])
if choice == 1:
cond_shape = [3]
else:
cond_shape = [1, 2]
else:
cond_shape = []
cond_tens = self.ser.addOutput(cond_shape, cond_type)
self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
# BODY block (input: a, acc, iter, output: a, acc, iter)
# Note that local intermediate tensors need to be declared here for the outputs
self.ser.addBasicBlock(body_block)
if error_name == ErrorIf.InputListBodyGraphInputMismatch:
self.ser.addInputTensor(incorrect_iter)
self.ser.addInputTensor(a)
self.ser.addInputTensor(incorrect_acc)
else:
self.ser.addInputTensor(iter)
self.ser.addInputTensor(a)
self.ser.addInputTensor(acc)
one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
iter_body_out = self.ser.addIntermediate(
incorrect_iter.shape, incorrect_iter.dtype
)
acc_body_out = self.ser.addIntermediate(
incorrect_acc.shape, incorrect_acc.dtype
)
else:
iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
self.ser.addOutputTensor(iter_body_out)
self.ser.addOutputTensor(a)
self.ser.addOutputTensor(acc_body_out)
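        # Net effect: the generated loop leaves acc = a * iter_val
        # (elementwise), since each body iteration adds `a` to the accumulator
        # and decrements the counter until COND's iter > 0 test fails.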
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
basicBlocks=self.ser.currRegion.basicBlocks,
):
return None
return acc_out
def build_fft2d(
self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)
input_names = [val1.name, val2.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
output_names = [res.name for res in results]
output_shapes = [res.shape for res in results]
output_dtypes = [res.dtype for res in results]
input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_names, output_names
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
inverse=inverse,
input1=val1,
input2=val2,
input_shape=val1.shape,
input_dtype=val1.dtype,
output_shape=output_shapes,
output_dtype=output_dtypes,
result_tensors=results,
input_list=input_names,
output_list=output_names,
num_operands=num_operands,
):
return None
attr = ts.TosaSerializerAttribute()
attr.FFTAttribute(inverse)
self.ser.addOperator(op["op"], input_names, output_names, attr)
return results
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)
input_names = [val.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
output_names = [res.name for res in results]
output_shapes = [res.shape for res in results]
output_dtypes = [res.dtype for res in results]
input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_names, output_names
)
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
input_shape=val.shape,
input_dtype=val.dtype,
output_shape=output_shapes,
output_dtype=output_dtypes,
result_tensors=results,
input_list=input_names,
output_list=output_names,
num_operands=num_operands,
):
return None
self.ser.addOperator(op["op"], input_names, output_names)
return results
def create_filter_lists(
self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
# Create a default testing rank range of 1-4 inclusive, to keep test sizes reasonably small.
default_test_rank_range = range(1, 5)
if not shapeFilter:
shapeFilter = [None]
# Calculate the filters based on what is requested and what the operator allows
rmin, rmax = op["rank"]
if rankFilter is not None:
cleanRankFilter = []
# Ensure rankFilter values are allowed by operator
for rank in rankFilter:
if rank >= rmin and rank <= rmax:
cleanRankFilter.append(rank)
elif rankFilter is None and shapeFilter[0] is None:
# Ensure default behaviour is bounded by default range or by operator,
# whichever is the smaller range of ranks.
opRankRange = range(rmin, rmax + 1)
cleanRankFilter = (
opRankRange
if len(opRankRange) <= len(default_test_rank_range)
else default_test_rank_range
)
else:
cleanRankFilter = range(rmin, rmax + 1)
dtypes = op["types"]
if dtypeFilter is not None:
cleanDtypeFilter = []
# Create list of operator dtypes filtered by requested dtypes
for dtype in dtypes:
if dtype in dtypeFilter or (
isinstance(dtype, list) and dtype[0] in dtypeFilter
):
cleanDtypeFilter.append(dtype)
else:
cleanDtypeFilter = dtypes
if testType == "positive":
filterDict = {
"shapeFilter": shapeFilter,
"rankFilter": cleanRankFilter,
"dtypeFilter": cleanDtypeFilter,
}
return filterDict
elif testType == "negative":
if validator is not None:
validator_info = validator(check=False, op=op)
else:
return None
error_arguments = validator_info["param_reqs"]
# Set parameters as required
if error_arguments["rank"] is not None:
rankFilter = error_arguments["rank"]
else:
rankFilter = cleanRankFilter
if error_arguments["dtype"] is not None:
dtypeFilter = error_arguments["dtype"]
else:
dtypeFilter = cleanDtypeFilter
if error_arguments["shape"] is not None:
shapeFilter = error_arguments["shape"]
else:
shapeFilter = shapeFilter[
:2
] # Reduce number of shapes to keep test numbers small
filterDict = {
"shapeFilter": shapeFilter,
"rankFilter": rankFilter,
"dtypeFilter": dtypeFilter,
}
return filterDict
def genOpTestList(
self,
opName,
shapeFilter=[None],
rankFilter=None,
dtypeFilter=None,
testType="positive",
):
try:
op = self.TOSA_OP_LIST[opName]
except KeyError:
raise Exception("Cannot find op with name {}".format(opName))
# Initialize a new random number generator
self.rng = np.random.default_rng(self.random_seed)
build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
# Each test list entry is a tuple of:
# (opName, testNameStr, dtype, errorName, shapeList, argumentsList)
testList = []
if testType == "negative" and "error_if_validators" in op:
error_if_validators = op["error_if_validators"]
else:
error_if_validators = [None]
for validator in error_if_validators:
if validator is not None:
error_name = validator(check=False, op=op)["error_name"]
else:
error_name = None
filterDict = self.create_filter_lists(
op, shapeFilter, rankFilter, dtypeFilter, testType, validator
)
if filterDict is None:
return []
cleanRankFilter = filterDict["rankFilter"]
cleanDtypeFilter = filterDict["dtypeFilter"]
cleanShapeFilter = filterDict["shapeFilter"]
# print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
for r in cleanRankFilter:
for t in cleanDtypeFilter:
for shape in cleanShapeFilter:
# Skip shapes whose rank does not match the rank currently under test
if shape is not None and len(shape) != r:
continue
self.setTargetShape(shape)
shapeList = tgen_fcn(self, op, r, error_name)
shapeStr = self.shapeStr(shapeList[0])
typeStr = self.typeStr(t)
# The argument list consists of (str, list) tuples: the string form is appended to the test name and the list is passed to the build function
argList = []
if agen_fcn:
argList = agen_fcn(self, opName, shapeList, t, error_name)
else:
argList = [("", [])]
for argStr, args in argList:
if testType == "positive":
if argStr:
testStr = "{}_{}_{}_{}".format(
opName, shapeStr, typeStr, argStr
)
else:
testStr = "{}_{}_{}".format(
opName, shapeStr, typeStr
)
elif testType == "negative":
if argStr:
testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
opName, error_name, shapeStr, typeStr, argStr
)
else:
testStr = "{}_ERRORIF_{}_{}_{}".format(
opName, error_name, shapeStr, typeStr
)
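# Illustrative test names (exact shape/type strings come from shapeStr()/typeStr()):
#   positive: "add_3x4x5_i32"
#   negative: "add_ERRORIF_WrongInputType_3x4x5_i32"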
testList.append(
(opName, testStr, t, error_name, shapeList, args)
)
if testType == "positive":
# Remove tests which are expected to fail but do not correlate to an ERROR_IF statement
if "invalid_test_validators" in op:
invalid_test_validators = op["invalid_test_validators"]
clean_testList = []
for test in testList:
remove_test = False
for validator_fcn in invalid_test_validators:
if validator_fcn(
opName=test[0],
input_dtype=test[2],
shapeList=test[4],
args=test[5],
):
remove_test = True
if not remove_test:
clean_testList.append(test)
testList = clean_testList
return testList
def serializeTest(
self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
try:
op = self.TOSA_OP_LIST[opName]
except KeyError:
raise Exception("Cannot find op with name {}".format(opName))
if self.args.verbose:
print(f"Creating {testStr}")
# Create a serializer
self.createSerializer(opName, testStr)
build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
if "error_if_validators" in op:
error_if_validators = op["error_if_validators"]
else:
error_if_validators = None
pCount, cCount = op["operands"]
num_operands = pCount + cCount
if isinstance(dtype_or_dtypeList, list):
dtypeList = dtype_or_dtypeList
elif op["op"] == Op.CONCAT:
dtypeList = [dtype_or_dtypeList] * len(shapeList)
else:
dtypeList = [dtype_or_dtypeList] * (num_operands)
if op["op"] != Op.CONCAT:
assert (
len(shapeList) == num_operands
), "shapeList length {} must match number of operands {}".format(
len(shapeList), num_operands
)
assert (
len(dtypeList) == num_operands
), "dtypeList length {} must match number of operands {}".format(
len(dtypeList), num_operands
)
try:
qgen = op["qgen"]
except KeyError:
qgen = None
# Build the random tensor operands and the test
tens = []
if qgen is not None:
qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
else:
qinfo = None
tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
try:
if error_if_validators is None:
if qinfo is not None:
resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
else:
resultName = build_fcn(self, op, *tens, *testArgs)
else:
if qinfo is not None:
resultName = build_fcn(
self,
op,
*tens,
*testArgs,
validator_fcns=error_if_validators,
error_name=error_name,
qinfo=qinfo,
)
else:
resultName = build_fcn(
self,
op,
*tens,
*testArgs,
validator_fcns=error_if_validators,
error_name=error_name,
)
except TypeError as e:
print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
raise e
if resultName:
# The test is valid, serialize it
self.serialize("test")
else:
# The test is not valid
print(f"Invalid ERROR_IF test created: {opName} {testStr}")
def createDynamicOpLists(self):
if "conv2d_TEMPLATE" not in self.TOSA_OP_LIST:
# Dynamic op lists already created and the templates removed (can occur when the class is initialized more than once)
return
# Dynamically create op lists for convolutions with a list of kernel sizes
if not self.args.level8k:
KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
else:
bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
KERNELS_2D = [[1, bigK], [bigK, 2]]
KERNELS_3D = [[1, bigK, 1], [2, 2, bigK]]
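# e.g. with the default kernel lists this creates ops such as "conv2d_3x3",
# "depthwise_conv2d_1x3", "transpose_conv2d_5x5" and "conv3d_1x2x1" from the
# *_TEMPLATE definitions below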
for k in KERNELS_2D:
testName = "conv2d_{}x{}".format(k[0], k[1])
self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
self.TOSA_OP_LIST[testName]["filter"] = k
self.TOSA_OP_LIST[testName]["template"] = False
testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
"depthwise_conv2d_TEMPLATE"
].copy()
self.TOSA_OP_LIST[testName]["filter"] = k
self.TOSA_OP_LIST[testName]["template"] = False
testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
"transpose_conv2d_TEMPLATE"
].copy()
self.TOSA_OP_LIST[testName]["filter"] = k
self.TOSA_OP_LIST[testName]["template"] = False
for k in KERNELS_3D:
testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
self.TOSA_OP_LIST[testName]["filter"] = k
self.TOSA_OP_LIST[testName]["template"] = False
# Delete any templates after having created the dynamic ops.
# This is a two-pass operation because deleting keys from a
# dictionary while iterating over it is an error in Python 3
keyList = []
for k in self.TOSA_OP_LIST:
try:
if self.TOSA_OP_LIST[k]["template"]:
keyList.append(k)
continue
except KeyError:
pass
for k in keyList:
del self.TOSA_OP_LIST[k]
def initOpListDefaults(self):
"""Fill in default fields for ops if they aren't already specified.
Look for missing required fields (datastructure linting)."""
for op in self.TOSA_OP_LIST:
# Required fields
try:
pl, c = self.TOSA_OP_LIST[op]["operands"]
except (KeyError, ValueError, TypeError):
raise Exception(
"Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
)
try:
fcn, tgen, tvgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
except (KeyError, ValueError, TypeError):
raise Exception(
"Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
op
)
)
try:
_ = self.TOSA_OP_LIST[op]["types"]
except KeyError:
raise Exception(
"Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
)
try:
_ = self.TOSA_OP_LIST[op]["op"]
except KeyError:
raise Exception(
"Op {} is missing the Op field in TOSA_OP_LIST".format(op)
)
# Put in default rank range, if missing
try:
_ = self.TOSA_OP_LIST[op]["rank"]
except KeyError:
self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
# Tensor operator list
# 'op': op name
# 'operands': tuple of (placeholder count, const count) operands
# 'rank': optional, restricts rank to a tuple inclusive of (min, max);
# if not specified, defaults to DEFAULT_RANK_RANGE (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': tuple of (build function, TensorGen function, TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
TYPE_INT_FP = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.FP16,
DType.BF16,
DType.FP32,
] # Excludes INT4
TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
DType.FP32,
DType.FP16,
DType.BF16,
DType.INT32,
] # floating-types and INT32
TYPE_FIB = [
DType.FP16,
DType.BF16,
DType.FP32,
DType.INT8,
DType.INT16,
DType.INT32,
DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
[DType.INT8, DType.INT4, DType.INT32],
[DType.INT8, DType.INT8, DType.INT32],
[DType.INT16, DType.INT8, DType.INT48],
[DType.FP16, DType.FP16, DType.FP16],
[DType.FP16, DType.FP16, DType.FP32],
[DType.BF16, DType.BF16, DType.FP32],
[DType.FP32, DType.FP32, DType.FP32],
]
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
TOSA_OP_LIST = {
# Tensor operators
"argmax": {
"op": Op.ARGMAX,
"operands": (1, 0),
"rank": (1, 6),
"build_fcn": (
build_argmax,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agAxis,
),
"types": TYPE_NARROW_INT_FP,
"error_if_validators": (
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evArgmaxOutputRankMismatch,
TosaErrorValidator.evArgmaxOutputShapeMismatch,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"avg_pool2d": {
"op": Op.AVG_POOL2D,
"operands": (1, 0),
"rank": (4, 4),
"build_fcn": (
build_pool2d,
TosaTensorGen.tgNHWC,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agPooling,
),
"qgen": TosaQuantGen.qgUnary,
"types": TYPE_NARROW_INT_FP,
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evKernelSmallerOne,
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evOutputZeroPointNotZero,
TosaErrorValidator.evPadLargerEqualKernel,
TosaErrorValidator.evPoolingOutputShapeMismatch,
TosaErrorValidator.evPoolingOutputShapeNonInteger,
),
},
# Templated operator. Filled in by createDynamicOpLists
"conv2d_TEMPLATE": {
"op": Op.CONV2D,
"operands": (1, 2),
"rank": (4, 4),
"build_fcn": (
build_conv2d,
TosaTensorGen.tgConv2D,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWeightZeroPointNotZero,
TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evDilationSmallerOne,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
),
"template": True,
},
# Templated operator. Filled in by createDynamicOpLists
"conv3d_TEMPLATE": {
"op": Op.CONV3D,
"operands": (1, 2),
"rank": (5, 5),
"build_fcn": (
build_conv3d,
TosaTensorGen.tgConv3D,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWeightZeroPointNotZero,
TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evDilationSmallerOne,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
),
"template": True,
},
# Templated operator. Filled in by createDynamicOpLists
"depthwise_conv2d_TEMPLATE": {
"op": Op.DEPTHWISE_CONV2D,
"operands": (1, 2),
"filter": [1, 1],
"rank": (4, 4),
"build_fcn": (
build_depthwise_conv2d,
TosaTensorGen.tgDepthwiseConv2D,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWeightZeroPointNotZero,
TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evDilationSmallerOne,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
),
"template": True,
},
"fully_connected": {
"op": Op.FULLY_CONNECTED,
"operands": (1, 2),
"rank": (2, 2),
"build_fcn": (
build_fully_connected,
TosaTensorGen.tgFullyConnected,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agFullyConnected,
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWeightZeroPointNotZero,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"matmul": {
"op": Op.MATMUL,
"operands": (2, 0),
"rank": (3, 3),
"build_fcn": (
build_matmul,
TosaTensorGen.tgMatmul,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agMatMul,
),
"qgen": TosaQuantGen.qgMatmul,
"types": TYPE_NARROW_INT_FP,
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"max_pool2d": {
"op": Op.MAX_POOL2D,
"operands": (1, 0),
"rank": (4, 4),
"build_fcn": (
build_maxpool2d,
TosaTensorGen.tgNHWC,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agPooling,
),
"types": TYPE_NARROW_INT_FP,
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evKernelSmallerOne,
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evPadLargerEqualKernel,
TosaErrorValidator.evPoolingOutputShapeMismatch,
TosaErrorValidator.evPoolingOutputShapeNonInteger,
),
},
# Templated operator. Filled in by createDynamicOpLists
"transpose_conv2d_TEMPLATE": {
"op": Op.TRANSPOSE_CONV2D,
"operands": (1, 2),
"rank": (4, 4),
"build_fcn": (
build_transpose_conv2d,
TosaTensorGen.tgTransposeConv2D,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agTransposeConv2D,
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
"invalid_test_validators": (
TosaInvalidValidator.ivHeightWidthInvalid,
TosaInvalidValidator.ivNonPositiveOutputShape,
),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWeightZeroPointNotZero,
TosaErrorValidator.evPadLargerEqualKernel,
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
),
"template": True,
},
# Activation functions
"clamp": {
"op": Op.CLAMP,
"operands": (1, 0),
"build_fcn": (
build_clamp,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_NARROW_INT_FP,
"error_if_validators": (
TosaErrorValidator.evMaxSmallerMin,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"sigmoid": {
"op": Op.SIGMOID,
"operands": (1, 0),
"build_fcn": (
build_sigmoid,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"tanh": {
"op": Op.TANH,
"operands": (1, 0),
"build_fcn": (
build_tanh,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"erf": {
"op": Op.ERF,
"operands": (1, 0),
"build_fcn": (
build_erf,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
# Elementwise Binary Operators
"add": {
"op": Op.ADD,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgAddSub,
None,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"arithmetic_right_shift": {
"op": Op.ARITHMETIC_RIGHT_SHIFT,
"operands": (2, 0),
"build_fcn": (
build_arithmetic_right_shift,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgArithmeticRightShift,
TosaArgGen.agArithmeticRightShift,
),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_and": {
"op": Op.BITWISE_AND,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_or": {
"op": Op.BITWISE_OR,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_xor": {
"op": Op.BITWISE_XOR,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"intdiv": {
"op": Op.INTDIV,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgIntDiv,
None,
),
"types": [DType.INT32],
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_and": {
"op": Op.LOGICAL_AND,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_left_shift": {
"op": Op.LOGICAL_LEFT_SHIFT,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgLogicalShift,
None,
),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_right_shift": {
"op": Op.LOGICAL_RIGHT_SHIFT,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgLogicalShift,
None,
),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_or": {
"op": Op.LOGICAL_OR,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_xor": {
"op": Op.LOGICAL_XOR,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"maximum": {
"op": Op.MAXIMUM,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"minimum": {
"op": Op.MINIMUM,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"mul": {
"op": Op.MUL,
"operands": (2, 0),
"build_fcn": (
build_mul,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgMul,
TosaArgGen.agMul,
),
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"pow": {
"op": Op.POW,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"sub": {
"op": Op.SUB,
"operands": (2, 0),
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgAddSub,
None,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"table": {
"op": Op.TABLE,
# Use the automatic generation functions to create the input array
# but create the table tensor in the build function, as it may be
# a different type from the input
"operands": (1, 0),
"build_fcn": (
build_table,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agTable,
),
"types": [DType.INT8, DType.INT16],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
# Elementwise Unary operators
"abs": {
"op": Op.ABS,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"bitwise_not": {
"op": Op.BITWISE_NOT,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"ceil": {
"op": Op.CEIL,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"clz": {
"op": Op.CLZ,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": [DType.INT32],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"exp": {
"op": Op.EXP,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"floor": {
"op": Op.FLOOR,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"log": {
"op": Op.LOG,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"logical_not": {
"op": Op.LOGICAL_NOT,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"negate": {
"op": Op.NEGATE,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgNegate,
None,
),
"qgen": TosaQuantGen.qgUnary,
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evOutputZeroPointNotZero,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"reciprocal": {
"op": Op.RECIPROCAL,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"rsqrt": {
"op": Op.RSQRT,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
# Elementwise Ternary operators
"select": {
"op": Op.SELECT,
"operands": (3, 0),
"build_fcn": (
build_select,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgSelect,
None,
),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
# Comparison operators
"equal": {
"op": Op.EQUAL,
"operands": (2, 0),
"build_fcn": (
build_comparison,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgEqual,
None,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"greater_equal": {
"op": Op.GREATER_EQUAL,
"operands": (2, 0),
"build_fcn": (
build_comparison,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"greater": {
"op": Op.GREATER,
"operands": (2, 0),
"build_fcn": (
build_comparison,
TosaTensorGen.tgBroadcastFuzz,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
},
# Reduction operators
"reduce_all": {
"op": Op.REDUCE_ALL,
"operands": (1, 0),
"rank": (1, 4),
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agAxis,
),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evShapeOfAxisNotOne,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"reduce_any": {
"op": Op.REDUCE_ANY,
"operands": (1, 0),
"rank": (1, 4),
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agAxis,
),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evShapeOfAxisNotOne,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"reduce_max": {
"op": Op.REDUCE_MAX,
"operands": (1, 0),
"rank": (1, 4),
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agAxis,
),
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evShapeOfAxisNotOne,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"reduce_min": {
"op": Op.REDUCE_MIN,
"operands": (1, 0),
"rank": (1, 4),
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agAxis,
),
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evShapeOfAxisNotOne,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"reduce_product": {
"op": Op.REDUCE_PRODUCT,
"operands": (1, 0),
"rank": (1, 4),
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agAxis,
),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evShapeOfAxisNotOne,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"reduce_sum": {
"op": Op.REDUCE_SUM,
"operands": (1, 0),
"rank": (1, 4),
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgReduceSum,
TosaArgGen.agAxis,
),
"types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evShapeOfAxisNotOne,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
# Data layout operators
"concat": {
"op": Op.CONCAT,
"operands": (2, 0),
"build_fcn": (
build_concat,
TosaTensorGen.tgConcat,
TosaTensorValuesGen.tvgConcat,
TosaArgGen.agAxis,
),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evConcatInputRankMismatch,
TosaErrorValidator.evConcatShapeSumMismatch,
TosaErrorValidator.evConcatInputDimMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongOutputList,
),
},
"pad": {
"op": Op.PAD,
"operands": (1, 0),
"build_fcn": (
build_pad,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agPad,
),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evPadOutputShapeMismatch,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongRank,
),
},
"dim": {
"op": Op.DIM,
"operands": (1, 0),
"build_fcn": (
build_dim,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agAxis,
),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evWrongRank,
),
},
"reshape": {
"op": Op.RESHAPE,
"operands": (1, 0),
"build_fcn": (
build_reshape,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agReshape,
),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evTensorSizeInputOutputMismatch,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evReshapeOutputSizeMultiInference,
TosaErrorValidator.evReshapeOutputSizeNonInteger,
),
},
"reverse": {
"op": Op.REVERSE,
"operands": (1, 0),
"build_fcn": (
build_reverse,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agAxis,
),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"slice": {
"op": Op.SLICE,
"operands": (1, 0),
"rank": (1, 6),
"build_fcn": (
build_slice,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agSlice,
),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evStartSmallerZero,
TosaErrorValidator.evSizeSmallerEqualZero,
TosaErrorValidator.evStartSizeOutsideBounds,
TosaErrorValidator.evSizeOutputShapeMismatch,
TosaErrorValidator.evInputSizeStartLengthMismatch,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evRankMismatch,
),
},
"tile": {
"op": Op.TILE,
"operands": (1, 0),
"rank": (1, 6),
"build_fcn": (
build_tile,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agTile,
),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongRank,
),
},
"transpose": {
"op": Op.TRANSPOSE,
"operands": (1, 0),
"rank": (1, 6),
"build_fcn": (
build_transpose,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agTranspose,
),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evIndexOutsideBounds,
TosaErrorValidator.evIndexUsedTwice,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evTensorSizeInputOutputMismatch,
),
},
# Data nodes
"const": {
"op": Op.CONST,
"operands": (0, 1),
"build_fcn": (
build_const,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FIB + [DType.INT48],
},
"identity": {
"op": Op.IDENTITY,
"operands": (1, 0),
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_FIB,
},
# Scatter/Gather
"gather": {
"op": Op.GATHER,
# Only specify the 'values' tensor here. 'indices' is generated in the op building stage
"operands": (1, 0),
"rank": (3, 3),
"build_fcn": (
build_gather,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": (
DType.INT8,
DType.INT16,
DType.INT32,
DType.FP16,
DType.BF16,
DType.FP32,
),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evWrongRank,
),
},
"scatter": {
"op": Op.SCATTER,
# Only specify the 'values_in' tensor here.
# 'indices' and 'input' are generated in the op building stage
"operands": (2, 0),
"rank": (3, 3),
"build_fcn": (
build_scatter,
TosaTensorGen.tgScatter,
TosaTensorValuesGen.tvgDefault,
None,
),
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evWrongRank,
),
},
# Image operations
"resize": {
"op": Op.RESIZE,
"operands": (1, 0),
"rank": (4, 4),
"build_fcn": (
build_resize,
TosaTensorGen.tgNHWC,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agResize,
),
"types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
"invalid_test_validators": (
TosaInvalidValidator.ivWrongDataTypeOrModeResize,
),
"error_if_validators": (
TosaErrorValidator.evMaxDimExceeded,
TosaErrorValidator.evScaleSmallerEqualZero,
TosaErrorValidator.evScaleNLargerMax,
TosaErrorValidator.evScaleDLargerMax,
TosaErrorValidator.evOffsetSmallerMin,
TosaErrorValidator.evOffsetLargerEqualMax,
TosaErrorValidator.evBorderSmallerMin,
TosaErrorValidator.evBorderLargerEqualMax,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evBatchMismatch,
TosaErrorValidator.evChannelMismatch,
TosaErrorValidator.evResizeOutputShapeMismatch,
TosaErrorValidator.evResizeOutputShapeNonInteger,
),
},
# Type conversion
"cast": {
"op": Op.CAST,
"operands": (1, 0),
"build_fcn": (
build_cast,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agCast,
),
"types": (
DType.FP16,
DType.BF16,
DType.FP32,
DType.INT8,
DType.INT16,
DType.INT32,
DType.BOOL,
),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
"rescale": {
"op": Op.RESCALE,
"operands": (1, 0),
"build_fcn": (
build_rescale,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agRescale,
),
"types": [
DType.UINT8,
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.UINT16,
],
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evOutputZeroPointNotZero,
TosaErrorValidator.evU16InputZeroPointNotValid,
TosaErrorValidator.evU16OutputZeroPointNotValid,
TosaErrorValidator.evScaleTrue,
TosaErrorValidator.evScaleNotTrue,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
},
# Custom
# Not implemented.
# Control flow operators
# Two varients of cond_if, one that generates one of two constant tensors (no
# inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
# (two inputs to the basic blocks, one output)
"cond_if_const": {
"op": Op.COND_IF,
"operands": (0, 2),
"build_fcn": (
build_cond_if_const,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgCondIfWhileLoop,
TosaArgGen.agCondIf,
),
"types": [DType.BOOL],
"error_if_validators": (
TosaErrorValidator.evOutputListThenGraphMismatch,
TosaErrorValidator.evOutputListElseGraphMismatch,
TosaErrorValidator.evCondIfCondNotMatchingBool,
TosaErrorValidator.evCondIfCondShapeNotSizeOne,
),
},
"cond_if_binary": {
"op": Op.COND_IF,
"operands": (2, 0),
"build_fcn": (
build_cond_if_binary,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgCondIfWhileLoop,
TosaArgGen.agCondIf,
),
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evInputListThenGraphMismatch,
TosaErrorValidator.evInputListElseGraphMismatch,
TosaErrorValidator.evOutputListThenGraphMismatch,
TosaErrorValidator.evOutputListElseGraphMismatch,
TosaErrorValidator.evCondIfCondNotMatchingBool,
TosaErrorValidator.evCondIfCondShapeNotSizeOne,
),
},
# while_loop
"while_loop": {
"op": Op.WHILE_LOOP,
"operands": (0, 1),
"build_fcn": (
build_while_loop,
TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgCondIfWhileLoop,
TosaArgGen.agWhileLoop,
),
"types": [DType.INT32],
"error_if_validators": (
TosaErrorValidator.evInputListOutputListMismatch,
TosaErrorValidator.evInputListCondGraphMismatch,
TosaErrorValidator.evInputListBodyGraphInputMismatch,
TosaErrorValidator.evInputListBodyGraphOutputMismatch,
TosaErrorValidator.evCondGraphOutputNotMatchingBool,
TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
),
},
"fft2d": {
"op": Op.FFT2D,
"operands": (2, 0),
"rank": (3, 3),
"build_fcn": (
build_fft2d,
TosaTensorGen.tgFFT2d,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agFFT2d,
),
"types": [DType.FP32],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evBatchMismatch,
TosaErrorValidator.evKernelNotPowerOfTwo,
TosaErrorValidator.evFFTInputShapeMismatch,
TosaErrorValidator.evFFTOutputShapeMismatch,
),
},
"rfft2d": {
"op": Op.RFFT2D,
"operands": (1, 0),
"rank": (3, 3),
"build_fcn": (
build_rfft2d,
TosaTensorGen.tgRFFT2d,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agNone,
),
"types": [DType.FP32],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evBatchMismatch,
TosaErrorValidator.evKernelNotPowerOfTwo,
TosaErrorValidator.evFFTOutputShapeMismatch,
),
},
}
class OutputShaper:
# Methods in this class compute the expected output shape and datatype
# for common classes of operations
def __init__(self):
pass
# These methods return arguments that can be used for
# creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
if error_name != ErrorIf.RankMismatch:
assert len(a.shape) == len(b.shape)
assert a.dtype == b.dtype
shape = []
for i in range(len(a.shape)):
if a.shape[i] == 1 and error_name is None:
shape.append(b.shape[i])
else:
shape.append(a.shape[i])
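# e.g. a.shape = [1, 4], b.shape = [3, 4] -> shape = [3, 4]
# (a size-1 dimension in 'a' broadcasts to the matching dimension of 'b')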
fuzz_idx = rng.integers(0, len(a.shape))
if error_name == ErrorIf.DimensionMismatch:
shape[fuzz_idx] += 1
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP16,
DType.BF16,
DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = a.dtype
return ser.addOutput(shape, outputDType)
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
assert len(a.shape) == len(b.shape)
assert a.dtype == b.dtype
shape = []
for i in range(len(a.shape)):
assert a.shape[i] == b.shape[i]
shape.append(a.shape[i])
return ser.addOutput(shape, a.dtype)
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = a.dtype
return ser.addOutput(a.shape, outputDType)
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
if error_name != ErrorIf.RankMismatch:
assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
assert a.dtype == b.dtype
shape = []
for i in range(len(cond.shape)):
if cond.shape[i] == 1 and error_name is None:
shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
else:
shape.append(cond.shape[i])
fuzz_idx = rng.integers(0, len(a.shape))
if error_name == ErrorIf.DimensionMismatch:
shape[fuzz_idx] += 1
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = a.dtype
return ser.addOutput(shape, outputDType)
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
if error_name != ErrorIf.RankMismatch:
assert len(a.shape) == len(b.shape)
assert a.dtype == b.dtype
# Do broadcast
shape = []
for i in range(len(a.shape)):
if a.shape[i] == 1 and len(b.shape) > i:
shape.append(b.shape[i])
else:
shape.append(a.shape[i])
fuzz_idx = rng.integers(0, len(a.shape))
if error_name == ErrorIf.DimensionMismatch:
shape[fuzz_idx] += 1
if error_name == ErrorIf.WrongOutputType:
wrong_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = DType.BOOL
return ser.addOutput(shape, outputDType)
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
shape = a.shape.copy()
if error_name not in [
ErrorIf.AxisSmallerZero,
ErrorIf.AxisLargerRank,
ErrorIf.ShapeOfAxisNotOne,
]:
shape[axis] = 1
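# e.g. shape = [2, 3, 4] with axis = 1 -> output shape [2, 1, 4]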
if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
shape[axis] = rng.integers(2, 10)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = a.dtype
return ser.addOutput(shape, outputDType)
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
shape = a.shape.copy()
if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
del shape[axis]
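# e.g. shape = [2, 3, 4] with axis = 1 -> output shape [2, 4]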
if error_name == ErrorIf.ArgmaxOutputRankMismatch:
remove = rng.choice([True, False])
if remove and len(shape) > 1:
del shape[0]
else:
shape.append(1)
elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
for i in range(len(shape)):
shape[i] = shape[i] + rng.integers(1, 10)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = DType.INT32
return ser.addOutput(shape, outputDType)
@staticmethod
def conv2dOp(
ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
# IFM: NHWC
# Filter: OHWI
# OFM: NHWC
h = (
ifm.shape[1]
- 1
+ padding[0]
+ padding[1]
- (filter.shape[1] - 1) * dilations[0]
) // strides[0] + 1
w = (
ifm.shape[2]
- 1
+ padding[2]
+ padding[3]
- (filter.shape[2] - 1) * dilations[1]
) // strides[1] + 1
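# Worked example (illustrative): ifm height 16, padding [0, 0],
# kernel height 3, dilation 1, stride 1:
#   h = (16 - 1 + 0 + 0 - (3 - 1) * 1) // 1 + 1 = 14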
if error_name == ErrorIf.ConvOutputShapeMismatch:
choices = [1, 2, 3]
change = rng.choice(choices)
# increment in multiples of the stride to avoid hitting the non-integer error case
if change in [1, 3]:
h = h + (rng.choice(choices) * strides[0])
if change in [2, 3]:
w = w + (rng.choice(choices) * strides[1])
ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
if error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
out_dtype = accum_dtype
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
wrong_dtypes = list(usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
def conv3dOp(
ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
# IFM: NDHWC
# Filter: ODHWI
# OFM: NDHWC
d = (
ifm.shape[1]
- 1
+ padding[0]
+ padding[1]
- (filter.shape[1] - 1) * dilations[0]
) // strides[0] + 1
h = (
ifm.shape[2]
- 1
+ padding[2]
+ padding[3]
- (filter.shape[2] - 1) * dilations[1]
) // strides[1] + 1
w = (
ifm.shape[3]
- 1
+ padding[4]
+ padding[5]
- (filter.shape[3] - 1) * dilations[2]
) // strides[2] + 1
if error_name == ErrorIf.ConvOutputShapeMismatch:
choices = [1, 2, 3, 4]
change = rng.choice(choices)
# increment in multiples of the stride to avoid hitting the non-integer error case
if change in [1, 4]:
d = d + (rng.choice(choices) * strides[0])
if change in [2, 4]:
h = h + (rng.choice(choices) * strides[1])
if change in [3, 4]:
w = w + (rng.choice(choices) * strides[2])
ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
if error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
out_dtype = accum_dtype
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
wrong_dtypes = list(usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
def depthwiseConv2dOp(
ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
# IFM: NHWC
# Filter: HWCM
# OFM: NHW C*M
h = (
ifm.shape[1]
- 1
+ padding[0]
+ padding[1]
- (filter.shape[0] - 1) * dilations[0]
) // strides[0] + 1
w = (
ifm.shape[2]
- 1
+ padding[2]
+ padding[3]
- (filter.shape[1] - 1) * dilations[1]
) // strides[1] + 1
if error_name == ErrorIf.ConvOutputShapeMismatch:
choices = [1, 2, 3]
change = rng.choice(choices)
# increment in multiples of the stride to avoid hitting the non-integer error case
if change in [1, 3]:
h = h + (rng.choice(choices) * strides[0])
if change in [2, 3]:
w = w + (rng.choice(choices) * strides[1])
ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
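# e.g. C = filter.shape[2] = 8 input channels with depth multiplier
# M = filter.shape[3] = 2 gives 16 output channels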
if error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
out_dtype = accum_dtype
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
wrong_dtypes = list(usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
# input: NHWC
if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
# If an incorrect stride or padding is used, set the dimensions to 1; the test is invalid anyway.
h = 1
w = 1
else:
h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
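# Worked example (illustrative): ifm height 32, padding [0, 0],
# kernel 2, stride 2:
#   h = (32 + 0 + 0 - 2) // 2 + 1 = 16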
if error_name == ErrorIf.PoolingOutputShapeMismatch:
choices = [1, 2, 3]
change = rng.choice(choices)
# Increment in multiples of the stride to avoid hitting the non-integer error case
if change in [1, 3]:
h = h + (rng.choice(choices) * stride[0])
if change in [2, 3]:
w = w + (rng.choice(choices) * stride[1])
ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - {ifm.dtype})
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = ifm.dtype
return ser.addOutput(ofm_shape, outputDType)
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
# input: N, IC
# filter: OC, IC
# output: N, OC
output_shape = [input.shape[0], filter.shape[0]]
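# Worked example with hypothetical values: input [N, IC] = [2, 16] and
# filter [OC, IC] = [8, 16] give output [N, OC] = [2, 8]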
# The accumulator dtype is validated in arg_gen (and deliberately invalidated there for ERROR_IF tests)
out_dtype = accum_dtype
return ser.addOutput(output_shape, out_dtype)
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
# a: N, H, C
# b: N, C, W
# out: N, H, W
output_shape = [a.shape[0], a.shape[1], b.shape[2]]
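# Worked example with hypothetical values: a [N, H, C] = [2, 3, 4] and
# b [N, C, W] = [2, 4, 5] give output [N, H, W] = [2, 3, 5]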
if error_name == ErrorIf.WrongOutputType:
if a.dtype == DType.INT8:
incorrect_types = (
DType.INT4,
DType.INT8,
DType.INT16,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
)
elif a.dtype == DType.INT16:
incorrect_types = (
DType.INT4,
DType.INT8,
DType.INT16,
DType.INT32,
DType.FP32,
DType.FP16,
DType.BF16,
)
elif (
a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
):
incorrect_types = (
DType.INT4,
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
)
out_dtype = rng.choice(a=incorrect_types)
elif error_name == ErrorIf.WrongInputType:
# Pick a potentially correct output dtype when the input type is incorrect
out_dtype = DType.INT32
else:
out_dtype = accum_dtype # Validated in arg_gen
return ser.addOutput(output_shape, out_dtype)
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
input1 = a[0]
remaining_inputs = a[1:]
# Calculate the output shape if possible; otherwise just use the first input shape
output_shape = input1.shape.copy()
if not (
# unable to concat tensors of different ranks
error_name == ErrorIf.ConcatInputRankMismatch
# unable to concat tensors along an invalid axis
or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
):
for tensor in remaining_inputs:
output_shape[axis] += tensor.shape[axis]
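# Worked example with hypothetical values: concatenating two [2, 3, 4]
# tensors along axis 1 gives output_shape = [2, 6, 4]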
if error_name == ErrorIf.ConcatShapeSumMismatch:
output_shape[axis] += rng.integers(5, 10)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = {
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
}
wrong_dtypes = list(all_dtypes - {input1.dtype})
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = input1.dtype
return ser.addOutput(output_shape, outputDType)
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
output_shape = a.shape.copy()
for i in range(len(output_shape)):
output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
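# Worked example with hypothetical values: a.shape = [4, 4] with
# padding = [[1, 1], [2, 0]] gives output_shape = [6, 6]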
if error_name == ErrorIf.PadOutputShapeMismatch:
bad_dim = rng.choice(range(len(output_shape)))
output_shape[bad_dim] -= rng.choice([1, 2])
elif error_name == ErrorIf.RankMismatch:
output_shape = get_rank_mismatch_shape(rng, output_shape)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - {a.dtype})
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = a.dtype
return ser.addOutput(output_shape, outputDType)
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
output_shape = []
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
# The only correct output dtype here is DType.SHAPE, so every entry
# in all_dtypes is a wrong choice
wrong_dtypes = all_dtypes
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = DType.SHAPE
return ser.addOutput(output_shape, outputDType)
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
output_shape = shape.copy()
if error_name == ErrorIf.TensorSizeInputOutputMismatch:
for i in range(len(output_shape)):
output_shape[i] = output_shape[i] + rng.integers(1, 10)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - {a.dtype})
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = a.dtype
return ser.addOutput(output_shape, outputDType)
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - {input.dtype})
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = input.dtype
output_shape = size.copy()
if error_name == ErrorIf.SizeOutputShapeMismatch:
for index in range(len(output_shape)):
if output_shape[index] <= 2:
output_shape[index] = output_shape[index] + rng.choice([1, 2])
else:
output_shape[index] = output_shape[index] + rng.choice(
[-2, -1, 1, 2]
)
elif error_name == ErrorIf.InputSizeStartLengthMismatch:
output_shape = input.shape.copy()
elif error_name == ErrorIf.RankMismatch:
output_shape = get_rank_mismatch_shape(rng, output_shape)
return ser.addOutput(output_shape, outputDType)
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
output_shape = a.shape.copy()
assert len(multiples) == len(output_shape)
for i in range(len(output_shape)):
output_shape[i] = a.shape[i] * multiples[i]
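# Worked example with hypothetical values: a.shape = [2, 3] with
# multiples = [2, 2] gives output_shape = [4, 6]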
if error_name == ErrorIf.RankMismatch:
output_shape = get_rank_mismatch_shape(rng, output_shape)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - {a.dtype})
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = a.dtype
return ser.addOutput(output_shape, outputDType)
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
output_shape = a.shape.copy()
assert len(perms) == len(output_shape)
if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
for i in range(len(output_shape)):
output_shape[i] = a.shape[perms[i]]
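# Worked example with hypothetical values: a.shape = [2, 3, 4] with
# perms = [2, 0, 1] gives output_shape = [4, 2, 3]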
if error_name == ErrorIf.TensorSizeInputOutputMismatch:
for i in range(len(output_shape)):
output_shape[i] += rng.integers(1, 10)
elif error_name == ErrorIf.RankMismatch:
output_shape = get_rank_mismatch_shape(rng, output_shape)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - {a.dtype})
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = a.dtype
return ser.addOutput(output_shape, outputDType)
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
if error_name != ErrorIf.WrongRank:
assert len(values.shape) == 3
assert len(indices.shape) == 2
assert values.shape[0] == indices.shape[0]
output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
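# Worked example with hypothetical values: values [N, K, C] = [2, 5, 3]
# and indices [N, W] = [2, 4] give output [N, W, C] = [2, 4, 3]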
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - {values.dtype})
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = values.dtype
return ser.addOutput(output_shape, outputDType)
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
if error_name != ErrorIf.WrongRank:
assert len(values_in.shape) == 3
assert len(indices.shape) == 2
assert len(input.shape) == 3
assert values_in.shape[0] == indices.shape[0] # N
assert input.shape[1] == indices.shape[1] # W
assert values_in.shape[2] == input.shape[2] # C
output_shape = values_in.shape
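# The output keeps the values_in shape [N, K, C]; indices [N, W] selects
# which of the K slots each of the W input rows is written into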
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - {values_in.dtype})
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = values_in.dtype
return ser.addOutput(output_shape, outputDType)
@staticmethod
def tableOp(ser, rng, input, error_name=None):
# Output has the same shape as the input; the output dtype depends on the input dtype
if error_name != ErrorIf.WrongInputType:
assert input.dtype == DType.INT16 or input.dtype == DType.INT8
output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
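# Per the TOSA spec, TABLE maps INT8 inputs to INT8 outputs and INT16
# inputs to INT32 outputs, which is what the selection above implements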
if error_name == ErrorIf.WrongOutputType:
wrong_dtypes = [
DType.INT8,
DType.INT16,
DType.INT32,
DType.INT48,
DType.FP32,
DType.FP16,
DType.BF16,
]
wrong_dtypes.remove(output_dtype)
output_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(input.shape, output_dtype)
@staticmethod
def resizeOp(
serializer,
rng,
input,
mode,
scale,
offset,
border,
input_dtype,
output_dtype,
error_name=None,
):
# Calculate OH, OW
scale_y_n = scale[0]
scale_y_d = scale[1]
scale_x_n = scale[2]
scale_x_d = scale[3]
if error_name == ErrorIf.ScaleSmallerEqualZero:
scale_y_n = max(scale_y_n, 1)
scale_y_d = max(scale_y_d, 1)
scale_x_n = max(scale_x_n, 1)
scale_x_d = max(scale_x_d, 1)
oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1
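# Worked example with hypothetical values: input NHWC = [1, 16, 16, 1],
# scale = [4, 2, 4, 2] (scale_y_n/d = 4/2, scale_x_n/d = 4/2), zero
# offset and border give oh = ow = ((16 - 1) * 4 - 0 + 0) // 2 + 1 = 31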
if error_name is not None:
# Make sure the output tensor is valid; it can become invalid when
# scale, offset or border have been changed for ERROR_IF tests
oh = max(oh, 1)
ow = max(ow, 1)
if error_name != ErrorIf.MaxDimExceeded:
oh = min(oh, MAX_RESIZE_DIMENSION - 1)
ow = min(ow, MAX_RESIZE_DIMENSION - 1)
if error_name == ErrorIf.ResizeOutputShapeMismatch:
choices = [1, 2, 3]
change = rng.choice(choices)
# Increment in multiples of scale_y_d/scale_x_d to avoid hitting the non-integer error case
if change in [1, 3]:
if oh + scale_y_d >= MAX_RESIZE_DIMENSION:
oh -= scale_y_d
assert oh > 0 # Should have been caught in agResize
else:
oh += scale_y_d
if change in [2, 3]:
if ow + scale_x_d >= MAX_RESIZE_DIMENSION:
ow -= scale_x_d
assert ow > 0 # Should have been caught in agResize
else:
ow += scale_x_d
if error_name == ErrorIf.WrongRank:
output_dims = [
input.shape[0],
oh,
ow,
input.shape[0],
]
elif error_name == ErrorIf.BatchMismatch:
output_dims = [
input.shape[0] + rng.integers(1, 10),
oh,
ow,
input.shape[3],
]
elif error_name == ErrorIf.ChannelMismatch:
output_dims = [
input.shape[0],
oh,
ow,
input.shape[3] + rng.integers(1, 10),
]
else:
output_dims = [input.shape[0], oh, ow, input.shape[3]]
return serializer.addOutput(output_dims, output_dtype)
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
return ser.addOutput(val.shape, out_dtype)
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
if error_name == ErrorIf.ConvOutputShapeMismatch:
choices = [1, 2, 3]
change = rng.choice(choices)
if change in [1, 3]:
output_shape[1] = output_shape[1] + rng.choice(choices)
if change in [2, 3]:
output_shape[2] = output_shape[2] + rng.choice(choices)
if error_name == ErrorIf.WrongInputType:
# Pick a potentially correct output dtype when the input type is incorrect
out_dtype = DType.INT32
else:
out_dtype = accum_dtype
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
wrong_dtypes = list(usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(output_shape, out_dtype)
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
outputs = []
assert ifm1.dtype == ifm2.dtype
input_dtype = ifm1.dtype
if error_name != ErrorIf.FFTInputShapeMismatch:
assert ifm1.shape == ifm2.shape
input_shape = ifm1.shape
if error_name != ErrorIf.WrongRank:
assert len(input_shape) == 3
output_shape = input_shape.copy()
output_dtype = input_dtype
if error_name == ErrorIf.WrongOutputType:
excludes = [DType.FP32]
wrong_dtypes = list(usableDTypes(excludes=excludes))
output_dtype = rng.choice(wrong_dtypes)
elif error_name == ErrorIf.BatchMismatch:
output_shape[0] += rng.integers(1, 10)
elif error_name == ErrorIf.FFTOutputShapeMismatch:
modify_dim = rng.choice([1, 2])
output_shape[modify_dim] += rng.integers(1, 10)
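# FFT2D produces two tensors (the real and imaginary parts), each with
# the same shape and dtype, hence the two addOutput calls below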
outputs.append(serializer.addOutput(output_shape, output_dtype))
outputs.append(serializer.addOutput(output_shape, output_dtype))
return outputs
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
outputs = []
input_shape = value.shape
if error_name != ErrorIf.WrongRank:
assert len(input_shape) == 3
output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]
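# Worked example with hypothetical values: input [N, H, W] = [1, 8, 8]
# gives output shape [1, 8, 8 // 2 + 1] = [1, 8, 5] for both the real
# and imaginary outputs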
output_dtype = value.dtype
if error_name == ErrorIf.WrongOutputType:
excludes = [DType.FP32]
wrong_dtypes = list(usableDTypes(excludes=excludes))
output_dtype = rng.choice(wrong_dtypes)
elif error_name == ErrorIf.BatchMismatch:
output_shape[0] += rng.integers(1, 10)
elif error_name == ErrorIf.FFTOutputShapeMismatch:
modify_dim = rng.choice([1, 2])
output_shape[modify_dim] += rng.integers(1, 10)
outputs.append(serializer.addOutput(output_shape, output_dtype))
outputs.append(serializer.addOutput(output_shape, output_dtype))
return outputs