# Copyright (c) 2021-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode
# DTypeNames, DType, Op and ResizeMode are convenience aliases for the
# flatc-generated types, which should be enums but aren't
class TosaQuantGen:
"""QuantizedInfo random generator helper functions.
Specify with 'qgen': in the operator definition.
"""
def __init__(self):
pass
@staticmethod
def getZeroPoint(testGen, dtype, error_name=None):
if dtype == DType.INT8:
return testGen.randInt(-128, 128)
elif dtype == DType.UINT8:
return testGen.randInt(0, 256)
elif error_name in [
ErrorIf.InputZeroPointNotZero,
ErrorIf.WeightZeroPointNotZero,
ErrorIf.OutputZeroPointNotZero,
]:
zero_point = testGen.randInt(-128, 128)
if zero_point == 0:
zero_point = 1
return zero_point
return 0
@staticmethod
def qgUnary(testGen, op, dtype, error_name=None):
if error_name == ErrorIf.InputZeroPointNotZero:
qinfo = [
TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
TosaQuantGen.getZeroPoint(testGen, dtype),
]
elif error_name == ErrorIf.OutputZeroPointNotZero:
qinfo = [
TosaQuantGen.getZeroPoint(testGen, dtype),
TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
]
else:
qinfo = [
TosaQuantGen.getZeroPoint(testGen, dtype),
TosaQuantGen.getZeroPoint(testGen, dtype),
]
return qinfo
@staticmethod
def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
if isinstance(dtype_or_dtypeList, list):
# a list of [input, weights, accumulator] dtypes
dtypeList = dtype_or_dtypeList
else:
# a single dtype: the [input, weights, accumulator] dtypes are all the same
dtypeList = [dtype_or_dtypeList] * 3
if error_name == ErrorIf.InputZeroPointNotZero:
qinfo = [
TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
]
elif error_name == ErrorIf.WeightZeroPointNotZero:
qinfo = [
TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
]
else:
qinfo = [
TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
]
return qinfo
@staticmethod
def qgMatmul(testGen, op, dtype, error_name=None):
if error_name == ErrorIf.InputZeroPointNotZero:
qinfo = [
TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
]
else:
qinfo = [
TosaQuantGen.getZeroPoint(testGen, dtype),
TosaQuantGen.getZeroPoint(testGen, dtype),
]
return qinfo
@staticmethod
def computeMultiplierAndShift(scaleFp, scale32):
# Derived from computeMultiplierAndShiftTosaScale32
# Provide a floating-point scaling factor and the scale32 parameter
# to compute the multiplier and shift
if scale32:
scaleBits = 31
else:
scaleBits = 15
m, shift = math.frexp(scaleFp)
if scaleFp < 0.0:
m = -m
multiplier = round(m * (1 << scaleBits))
assert multiplier <= (1 << scaleBits)
if multiplier == (1 << scaleBits):
multiplier = multiplier // 2
shift = shift + 1
shift = (-shift) + scaleBits
# print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
# scaleFp, scaleBits, m, multiplier, shift))
# Adjust multiplier such that shift is in allowed value range.
if shift == 0:
multiplier = multiplier // 4
shift = shift + 2
elif shift == 1:
multiplier = multiplier // 2
shift = shift + 1
elif shift == 63:
multiplier = multiplier * 2
shift = shift - 1
assert multiplier <= (1 << scaleBits)
assert shift >= 2 and shift <= 62
return multiplier, shift
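# A minimal sanity sketch (illustrative only, not part of the generator):
# deterministic examples of the fixed-point encoding produced by
# computeMultiplierAndShift above. The helper name below is hypothetical.
def _exampleMultiplierAndShift():
    # 0.5 in Q31: round(0.5 * 2**31) = 2**30 with a final shift of 31
    assert TosaQuantGen.computeMultiplierAndShift(0.5, True) == (1 << 30, 31)
    # 1.0 in Q31: frexp(1.0) == (0.5, 1), so the shift drops to 30
    assert TosaQuantGen.computeMultiplierAndShift(1.0, True) == (1 << 30, 30)
    # 0.5 in Q15 (scale32=False): round(0.5 * 2**15) = 2**14, shift 15
    assert TosaQuantGen.computeMultiplierAndShift(0.5, False) == (1 << 14, 15)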
class TosaTensorGen:
"""Tensor generators create a shape list for the placeholder and const tensor
data operands for the operator.
The actual random data is generated separately for each test.
"""
def __init__(self):
pass
@staticmethod
def tgBasic(testGen, opName, rank, error_name=None):
pl, const = opName["operands"]
shape = testGen.makeShape(rank)
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
shape_list = []
for i in range(pl + const):
shape_list.append(shape.copy())
if error_name == ErrorIf.RankMismatch:
if rank == 1 and i != 1:
shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
elif i != 1:
shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))
return shape_list
@staticmethod
def tgNHWC(testGen, opName, rank, error_name=None):
pl, const = opName["operands"]
if error_name != ErrorIf.WrongRank:
assert rank == 4
shape = testGen.makeShape(rank)
# Constrict the batch size?
if testGen.args.max_batch_size:
shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name and error_name != ErrorIf.MaxDimExceeded:
shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
shape_list = []
for i in range(pl + const):
shape_list.append(shape.copy())
return shape_list
@staticmethod
def tgScatter(testGen, opName, rank, error_name=None):
pl, const = opName["operands"]
assert pl == 2
assert const == 0
if error_name != ErrorIf.WrongRank:
assert rank == 3
values_in_shape = testGen.makeShape(rank)
# ignore max batch size if target shape is set
if testGen.args.max_batch_size and not testGen.args.target_shapes:
values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1
W = testGen.randInt(
testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
)
# Constrict W if one dimension is too large to keep tensor size reasonable
if max(values_in_shape) > 5000:
W = testGen.randInt(0, 16)
input_shape = [values_in_shape[0], W, values_in_shape[2]]
shape_list = []
shape_list.append(values_in_shape.copy())
shape_list.append(input_shape.copy())
return shape_list
@staticmethod
def tgBroadcastFuzz(testGen, op, rank, error_name=None):
shape = testGen.makeShape(rank)
pl, const = op["operands"]
shape_list = []
# Choose one of the inputs to broadcast
# Note: Simplifies OutputShaper code if we don't change first shape for errors
bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
for i in range(pl + const):
shape_bcast = shape.copy()
# If this is the chosen input, pick a random index to broadcast
if i == bcast_idx:
fuzz_idx = testGen.randInt(0, rank)
if error_name == ErrorIf.DimensionMismatch:
shape_bcast[fuzz_idx] += 1
elif error_name == ErrorIf.RankMismatch:
# Add one rank to the shape (or more for rank of 1)
extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
shape_bcast = np.concatenate(
(shape_bcast, testGen.makeShape(extra_ranks))
)
if rank != 1:
# Either keep the extra rank, or remove it
new_len = testGen.rng.choice([-2, len(shape_bcast)])
shape_bcast = shape_bcast[:new_len]
else:
shape_bcast[fuzz_idx] = 1
shape_list.append(shape_bcast)
return shape_list
@staticmethod
def tgConv2D(testGen, op, rank, error_name=None):
pl, const = op["operands"]
if error_name != ErrorIf.WrongRank:
assert rank == 4
# IFM dimensions are NHWC
ifm_shape = testGen.makeShape(rank)
# Constrict the batch size?
if testGen.args.max_batch_size:
ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
ifm_shape, max_dim=24, max_items=10000
)
# Get the filter height/width from the operator parameters
filter_hw = op["filter"]
# Generate a random OFM depth
ofm_depth = testGen.makeShape(1)[0]
# The filter dimensions are OHWI
filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
# The bias is OC
bias_shape = np.asarray([ofm_depth])
return [ifm_shape, filter_shape, bias_shape]
@staticmethod
def tgConv3D(testGen, op, rank, error_name=None):
pl, const = op["operands"]
if error_name != ErrorIf.WrongRank:
assert rank == 5
# IFM dimensions are NDHWC
ifm_shape = testGen.makeShape(rank)
# Constrict the batch size?
if testGen.args.max_batch_size:
ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
ifm_shape, max_dim=24, max_items=10000
)
# Get the filter depth/height/width from the operator parameters
filter_dhw = op["filter"]
# Generate a random OFM channel
ofm_channel = testGen.makeShape(1)[0]
# The filter dimensions are ODHWI
filter_shape = np.asarray(
[ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
)
# The bias is OC
bias_shape = np.asarray([ofm_channel])
return [ifm_shape, filter_shape, bias_shape]
@staticmethod
def tgTransposeConv2D(testGen, op, rank, error_name=None):
pl, const = op["operands"]
if error_name != ErrorIf.WrongRank:
assert rank == 4
# IFM dimensions are NHWC
ifm_shape = testGen.makeShape(rank)
# Constrict the batch size?
if testGen.args.max_batch_size:
ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
ifm_shape, max_dim=24, max_items=10000
)
# Get the filter height/width from the operator parameters
filter_hw = op["filter"]
# Generate a random OFM depth
ofm_depth = testGen.makeShape(1)[0]
# The filter dimensions are OHWI
filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
# The bias is OC
bias_shape = np.asarray([ofm_depth])
return [ifm_shape, filter_shape, bias_shape]
@staticmethod
def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
pl, const = op["operands"]
if error_name != ErrorIf.WrongRank:
assert rank == 4
assert pl == 1 and const == 2
# IFM dimensions are NHWC
ifm_shape = testGen.makeShape(rank)
# Constrict the batch size?
if testGen.args.max_batch_size:
ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
ifm_shape, max_dim=24, max_items=10000
)
# Get the filter height/width from the operator parameters
# Filter is KH, KW, C, M
filter_hw = op["filter"]
# Generate a random OFM depth, but don't let it get too big because
# the output depth is M * C
filter_m = (
testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
) + 1
# The filter dimensions are HWCM
filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])
# The bias is M * C
bias_shape = np.asarray([ifm_shape[3] * filter_m])
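# Worked example (illustrative): an NHWC input [1, 8, 8, 3] with
# filter_m == 2 gives an HWCM filter [kh, kw, 3, 2] and a bias of
# shape [6], since the output depth is M * C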
return [ifm_shape, filter_shape, bias_shape]
@staticmethod
def tgFullyConnected(testGen, op, rank, error_name=None):
pl, const = op["operands"]
if error_name != ErrorIf.WrongRank:
assert rank == 2
input_shape = testGen.makeShape(rank)
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)
filter_oc = testGen.rng.integers(
low=testGen.args.tensor_shape_range[0],
high=testGen.args.tensor_shape_range[1],
size=1,
)[0]
filter_shape = np.asarray([filter_oc, input_shape[1]])
bias_shape = np.asarray([filter_oc])
return [input_shape, filter_shape, bias_shape]
@staticmethod
def tgMatmul(testGen, op, rank, error_name=None):
pl, const = op["operands"]
if error_name != ErrorIf.WrongRank:
assert rank == 3
assert pl == 2 and const == 0
a_shape = testGen.makeShape(rank)
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)
# Get a random number for b_oc even if target shape is defined
b_oc = np.int32(
testGen.rng.integers(
low=testGen.args.tensor_shape_range[0],
high=testGen.args.tensor_shape_range[1],
size=1,
)
)[0]
# If any dimension is large, set b_oc to 1 to reduce the output tensor size
if max(a_shape) > 1000:
b_oc = 1
b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
return [a_shape, b_shape]
@staticmethod
def tgConcat(testGen, opName, rank, error_name=None):
pl, const = opName["operands"]
shape = testGen.makeShape(rank)
# Create extra tensors to concat.
# Take into account value of pl when getting maximum number of concats
num_tensors = testGen.randInt(0, 4)
shape_list = []
for i in range(pl + const + num_tensors):
if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
remove = testGen.rng.choice([True, False])
wrongShape = shape.copy()
if remove and len(shape) > 1:
wrongShape = wrongShape[1:]
else:
wrongShape = list(wrongShape)
wrongShape.append(testGen.rng.integers(1, 10))
shape_list.append(wrongShape)
else:
shape_list.append(shape.copy())
return shape_list
@staticmethod
def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
if error_name in [
ErrorIf.AxisSmallerZero,
ErrorIf.AxisLargerRank,
ErrorIf.ConcatInputRankMismatch,
]:
return shapeList
# Split concat shape along axis to allow for multiple const inputs
# without making too many large tensors
if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
# If axis can't be split we still need to invalidate other dimensions
if error_name == ErrorIf.ConcatInputDimMismatch:
for shape in shapeList[1:]:
# Negative test shapeLists are created individually for each test,
# so no need to copy the shape before altering it.
shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
return shapeList
# Create copy of shape we are going to split (so we don't alter shapeList)
shape = shapeList[0].copy()
# Add original shape as first input
new_shapeList = [shape.copy()]
length_on_axis = shape[axis]
remaining_length = length_on_axis
for i in range(len(shapeList) - 2):
# Calculate split on axis and remaining value
split_shape_val = int(shape[axis] / 2)
remaining_length = remaining_length - split_shape_val
# Append new shape, and set remaining shape
shape[axis] = split_shape_val
new_shapeList.append(shape.copy())
# invalidate dimensions
if error_name == ErrorIf.ConcatInputDimMismatch:
shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
else:
shape[axis] = remaining_length
if i == len(shapeList) - 3:
new_shapeList.append(shape.copy())
return new_shapeList
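# Worked example (illustrative): four equal shapes [2, 8, 3] with axis == 1
# become [[2, 8, 3], [2, 4, 3], [2, 2, 3], [2, 2, 3]]: the original shape
# followed by successive halvings of the remaining length on the axis.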
class TosaTensorValuesGen:
"""Tensor Value generators create the random data for each test."""
def __init__(self):
pass
@staticmethod
def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
pCount, cCount = op["operands"]
tens = []
tens.extend(
testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
)
tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
return tens
@staticmethod
def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if dtypeList[0] == DType.INT32 and error_name is None:
pCount, cCount = op["operands"]
assert (
pCount == 1 and cCount == 0
), "Op.NEGATE must have 1 placeholders, 0 consts"
# Must create tensor values within the negatable range of the
# int32 accumulator
max_val = (1 << 31) - 1
min_val = -max_val
arr = np.int32(
testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
)
placeholders = []
placeholders.append(
testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
)
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if dtypeList[0] == DType.INT32 and error_name is None:
# Make sure the operation does not cause value saturation, where
# the result overflows the number of bits available to store it
pCount, cCount = op["operands"]
assert (
pCount == 2 and cCount == 0
), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
placeholders = []
add = op["op"] == Op.ADD
a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
if add:
res_arr = np.add(a_arr, b_arr, dtype=np.int64)
else:
res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
# Work out the saturation limits
max_i32 = (1 << 31) - 1
min_i32 = -(1 << 31)
max_arr = np.full(shapeList[1], max_i32)
min_arr = np.full(shapeList[1], min_i32)
# Find how much the values exceed the maximum/minimum
sat_max_arr = np.maximum(res_arr - max_arr, 0)
sat_min_arr = np.minimum(res_arr - min_arr, 0)
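# Worked example (illustrative): if a == INT32_MAX and b == 5 for ADD,
# then res == INT32_MAX + 5 in int64, sat_max == 5 and sat_min == 0,
# so b is clipped down by 5 below to keep the int32 result unsaturated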
if not add:
# Swap saturation values and negate values as we need to perform opposite operations
sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
# Create new array of unsaturated values by clipping values as needed
b_unsat_arr = b_arr
if (sat_max_arr != 0).any():
# Clip values that cause saturation
b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
# Reduce axes in unsaturated tensor to match original tensor
for axis, dim in enumerate(b_arr.shape):
if dim != b_unsat_arr.shape[axis]:
assert (
dim == 1
), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
if (sat_min_arr != 0).any():
# Clip values that cause saturation
b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
# Reduce axes in unsaturated tensor to match original tensor
for axis, dim in enumerate(b_arr.shape):
if dim != b_unsat_arr.shape[axis]:
assert (
dim == 1
), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
placeholders.append(
testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
)
placeholders.append(
testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
)
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
def tvgCondIfWhileLoop(
testGen, op, dtypeList, shapeList, testArgs, error_name=None
):
if dtypeList[0] in (
DType.INT32,
DType.INT16,
DType.INT8,
):
# Limit input tensors with cond_if_binary or while_loop to stop
# saturation of add/sub ops with int32, and keep all logical shift
# values between 0 and 31 for int16 or int8
pCount, cCount = op["operands"]
pRemain = pCount
placeholders = []
for idx, shape in enumerate(shapeList[:]):
if dtypeList[0] == DType.INT32:
arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
else:
arr = np.int32(
testGen.rng.integers(low=0, high=32, size=shapeList[idx])
)
if pRemain > 0:
placeholders.append(
testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
)
pRemain -= 1
else:
placeholders.append(
testGen.ser.addConst(shape, dtypeList[idx], arr)
)
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
def tvgArithmeticRightShift(
testGen, op, dtypeList, shapeList, testArgs, error_name=None
):
pCount, cCount = op["operands"]
# Force the value of operand[1] to be within [0, num_bits - 1]
assert (
pCount == 2 and cCount == 0
), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
placeholders = []
for idx, shape in enumerate(shapeList[:]):
if idx == 1:
if dtypeList[idx] == DType.INT8:
arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
elif dtypeList[idx] == DType.INT16:
arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
elif dtypeList[idx] == DType.INT32:
arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
elif error_name == ErrorIf.WrongInputType:
arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
else:
raise Exception("OpArithmeticRightShift: invalid input dtype")
else:
arr = testGen.getRandTensor(shape, dtypeList[idx])
placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
return placeholders
@staticmethod
def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
# Set datatype of condition tensor to boolean
dtypeList[0] = DType.BOOL
return TosaTensorValuesGen.tvgDefault(
testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if error_name is None:
pCount, cCount = op["operands"]
assert (
pCount == 2 and cCount == 0
), "Op.INTDIV must have 2 placeholders, 0 consts"
placeholders = []
# Two invalid cases for Op.INTDIV:
# 1. divisor == 0
# 2. dividend == -(1<<31) and divisor == -1
while True:
dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
if (divisor_arr == 0).any():
continue
if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
continue
break
placeholders.append(
testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
)
placeholders.append(
testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
)
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if error_name is None:
pCount, cCount = op["operands"]
assert (
pCount == 2 and cCount == 0
), "Op.MUL must have 2 placeholders, 0 consts"
tens = []
if dtypeList[0] == DType.FLOAT:
tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
else:
placeholders = []
# Make sure the multiply result is within int32 range
shift = testArgs[0]
if dtypeList[0] == DType.INT8:
num_bits = 8
elif dtypeList[0] == DType.INT16:
num_bits = 16
elif dtypeList[0] == DType.INT32:
num_bits = 32
elif error_name == ErrorIf.WrongInputType:
num_bits = 8
else:
raise Exception("OpMul: invalid input dtype")
for idx, shape in enumerate(shapeList[:]):
low = -(2 ** (num_bits - 1))
high = (2 ** (num_bits - 1)) - 1
a_arr = np.int32(
testGen.rng.integers(low=low, high=high, size=shapeList[0])
)
b_arr = np.int32(
testGen.rng.integers(low=low, high=high, size=shapeList[1])
)
i = 0
while True:
a_arr_64 = a_arr.astype(np.int64)
b_arr_64 = b_arr.astype(np.int64)
if shift > 0:
rounding = 1 << (shift - 1)
result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
else:
result_arr = a_arr_64 * b_arr_64
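# Worked example (illustrative): with shift == 2, a == 3 and b == 2
# give (6 + 2) >> 2 == 2, i.e. 1.5 rounded half-up to 2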
if (result_arr > -(2**31)).all() and (
result_arr <= ((2**31) - 1)
).all():
break
i = i + 1
a_arr = a_arr // 2
b_arr = b_arr // 2
placeholders.append(
testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
)
placeholders.append(
testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
)
tens.extend(placeholders)
return tens
else:
return TosaTensorValuesGen.tvgDefault(
testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
count = len(shapeList) - testGen.args.num_const_inputs_concat
if count < 1:
count = 1
if testGen.args.num_const_inputs_concat == 0:
count = len(shapeList)
# Ensure axis is an int
testArgs[0] = int(testArgs[0])
shapeList = TosaTensorGen.tgConcatConstInput(
testGen, shapeList, testArgs[0], error_name
)
tens = []
tens.extend(
testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
)
tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:]))
return tens
@staticmethod
def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
pCount, cCount = op["operands"]
assert (
pCount == 2 and cCount == 0
), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
placeholders = []
placeholders.append(
testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
)
placeholders.append(
testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
)
return placeholders
@staticmethod
def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if error_name is None:
pCount, cCount = op["operands"]
assert (
pCount == 2 and cCount == 0
), "Op.EQUAL must have 2 placeholders, 0 consts"
a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
# Using random numbers means that matching (equal) values are very
# unlikely, so force twice as many matching values as the tensor rank
for num in range(0, len(shapeList[0]) * 2):
a_index = []
b_index = []
# Choose an index in each axis for the whole shape
for axis in range(0, len(shapeList[0])):
# Index can be up to the largest dimension in both shapes
index = np.int32(
testGen.rng.integers(
0, max(shapeList[0][axis], shapeList[1][axis])
)
)
# Reduce the index down to a shape's dim for broadcasting
a_index.append(min(shapeList[0][axis] - 1, index))
b_index.append(min(shapeList[1][axis] - 1, index))
a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
placeholders = []
placeholders.append(
testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
)
placeholders.append(
testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
)
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if dtypeList[0] == DType.INT32:
pCount, cCount = op["operands"]
assert (
pCount == 1 and cCount == 0
), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
# Limit values so that the sum cannot exceed the range of an int32 during
# summation of any axis
range_val = int((1 << 31) / max(shapeList[0]))
values_arr = np.int32(
testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
)
placeholders = []
placeholders.append(
testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
)
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
testGen, op, dtypeList, shapeList, testArgs, error_name
)
class TosaArgGen:
"""Argument generators create exhaustive or random lists of attributes for
operators that take attributes or other parameters.
The return value is a list of (descriptive_name, [arglist]) tuples where
the descriptive_name is appended to the test name and the arglist is expanded
as arguments to the operator build function.
"""
def __init__(self):
pass
@staticmethod
def agNone(testGen, opName, shapeList, dtype, error_name=None):
"""A trivial argument generator for operators that don't take any
non-tensor arguments"""
return [("", [])]
@staticmethod
def agAxis(testGen, opName, shapeList, dtype, error_name=None):
"""Build the axis argument for operators that take a single axis"""
axes = []
shape = shapeList[0]
if error_name == ErrorIf.AxisSmallerZero:
small_axis = testGen.rng.integers(-5, 0)
axes.append(("axis{}".format(small_axis), [small_axis]))
elif error_name == ErrorIf.AxisLargerRank:
large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
axes.append(("axis{}".format(large_axis), [large_axis]))
else:
for a in range(0, len(shape)):
axes.append(("axis{}".format(a), [a]))
return axes
@staticmethod
def agConv(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
ifm_shape = shapeList[0]
filter_shape = shapeList[1]
# determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
k = [int(x) for x in opName.split("_")[-1].split("x")]
# Check the rank
rank = 5 if opName.startswith("conv3d") else 4
if error_name != ErrorIf.WrongRank:
assert len(ifm_shape) == rank
assert len(filter_shape) == rank
# kernel rank omits batch and channels
k_rank = rank - 2
assert len(k) == k_rank
# Generate comprehensive argument lists
# - except for named errors, which use specific invalid value(s)
if error_name == ErrorIf.PadSmallerZero:
p_vals = [testGen.rng.choice(range(-5, 0))]
else:
p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
if error_name == ErrorIf.StrideSmallerOne:
# Can't use stride=0, as it is used as a divisor when deriving the output shape
s_vals = [testGen.rng.choice(range(-5, 0))]
else:
# Stride must be greater than 1 to force non-integer error
startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
strides = {x for x in itertools.product(*([s_vals] * k_rank))}
if error_name == ErrorIf.DilationSmallerOne:
d_vals = [testGen.rng.choice(range(-5, 1))]
else:
d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
if not error_name and testGen.args.oversize:
# add some oversize argument values
if max(ifm_shape) < 64:
bigPadding = 9
paddings.update(
{x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}
)
bigStride = 8
strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
bigDilation = 7
dilations.update(
{x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
)
# There are too many parameter combinations, so generate them sparsely;
# very sparsely for negative tests
sparsity_factor = 2 if error_name else 120
sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
# If there are only a small number of tests, just select them all
if sparsity < 13:
sparsity = 1
# To get a variety of parameter combinations sparsity should not be a
# multiple of 2, 3 or 5
while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
sparsity += 1
n = 0
for s in sorted(list(strides)):
for p in sorted(list(paddings)):
for d in sorted(list(dilations)):
if (
n % sparsity == 0
# padding must not exceed the kernel size ?
# and p[0] < k[0] and p[1] < k[0]
# and p[2] < k[1] and p[3] < k[1]
# and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
# the padded shape must exceed the kernel size
and (ifm_shape[1] + p[0] + p[1]) > k[0]
and (ifm_shape[2] + p[2] + p[3]) > k[1]
and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
# the padded shape must exceed the dilation
and (ifm_shape[1] + p[0] + p[1]) > d[0]
and (ifm_shape[2] + p[2] + p[3]) > d[1]
and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
):
remainders = []
for index in range(k_rank):
pad_offset = index * 2
remainders.append(
(
ifm_shape[index + 1]
- 1
+ p[pad_offset]
+ p[pad_offset + 1]
- (k[index] - 1) * d[index]
)
% s[index]
)
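# Worked example (illustrative): ifm H == 7, pad 0+0, k == 3, d == 1
# gives remainder (7 - 1 - 2) % s; stride 2 yields 0 (exact OH == 3)
# while stride 3 yields 1 (non-integer output)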
if (
# the parameters must produce integer exact output
error_name != ErrorIf.ConvOutputShapeNonInteger
and max(remainders) == 0
) or (
error_name == ErrorIf.ConvOutputShapeNonInteger
and max(remainders) > 0
):
arg_list.append(
(
"st{}_pad{}_dilat{}".format(
"".join([str(x) for x in s]),
"".join([str(x) for x in p]),
"".join([str(x) for x in d]),
),
[s, p, d],
)
)
n += 1
return arg_list
@staticmethod
def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
ifm_shape = shapeList[0]
filter_shape = shapeList[1]
# Must be rank 4
if error_name != ErrorIf.WrongRank:
assert len(ifm_shape) == 4
assert len(filter_shape) == 4
# Generate comprehensive argument lists
# - except for named errors, which use specific invalid value(s)
if error_name == ErrorIf.PadSmallerZero:
p_vals = [testGen.rng.choice(range(-5, 0))]
else:
p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
paddings = {x for x in itertools.product(*([p_vals] * 4))}
if error_name == ErrorIf.StrideSmallerOne:
# Can't use stride=0, as it is used as a divisor when deriving the output shape
s_vals = [testGen.rng.choice(range(-5, 0))]
else:
s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
strides = {x for x in itertools.product(*([s_vals] * 2))}
if not error_name and testGen.args.oversize:
# add some oversize argument values
if max(ifm_shape) < 64:
bigPadding = 9
paddings.update(
{x for x in itertools.product(*([[0, bigPadding]] * 4))}
)
bigStride = 8
strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
# There are too many parameter combinations, so generate them sparsely;
# very sparsely for negative tests
sparsity_factor = 2 if error_name else 10
sparsity = len(paddings) * len(strides) // sparsity_factor + 1
# If there are only a small number of tests, just select them all
if sparsity < 13:
sparsity = 1
# To get a variety of parameter combinations sparsity should not be a
# multiple of 2, 3 or 5
while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
sparsity += 1
n = 0
for s in sorted(list(strides)):
for p in sorted(list(paddings)):
if n % sparsity == 0:
# Determine the output shape
oh = (ifm_shape[1] - 1) * s[0] - p[0] - p[1] + filter_shape[1]
ow = (ifm_shape[2] - 1) * s[1] - p[2] - p[3] + filter_shape[2]
os = [ifm_shape[0], oh, ow, filter_shape[0]]
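# Worked example (illustrative): ifm [N, 4, 4, C], stride 2, zero
# padding and a 3x3 filter give oh == ow == (4 - 1) * 2 + 3 == 9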
arg_list.append(
(
"st{}_pad{}_os{}".format(
"".join([str(x) for x in s]),
"".join([str(x) for x in p]),
"x".join([str(x) for x in os]),
),
[s, p, os],
)
)
n += 1
return arg_list
@staticmethod
def agPad(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
rank = len(shapeList[0])
# Exhaustively test combinations of padding on each side of each dimension
# - the range of padding values is defined by pad_min and pad_max
# - for padding >9, the name format needs to be more distinctive
pad_min, pad_max = 0, 1
pad_values = [x for x in range(pad_min, pad_max + 1)]
if error_name == ErrorIf.PadSmallerZero:
pad_values = [x for x in range(-2, 0)]
axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
shape_pad_values = itertools.product(*([axis_pad_values] * rank))
if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
pad_const_int = testGen.getRandNumberDType(dtype)
pad_const_fp = 0
elif dtype == DType.FLOAT:
pad_const_int = 0
pad_const_fp = testGen.getRandNumberDType(dtype)
else:
return []
for paddings in shape_pad_values:
name = "pad"
for r in range(rank):
before, after = paddings[r]
name = f"{name}{before}{after}"
arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))
return arg_list
@staticmethod
def agPooling(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
shape = shapeList[0]
if error_name != ErrorIf.WrongRank:
assert len(shape) == 4
# Generate comprehensive argument lists
p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
paddings = {x for x in itertools.product(*([p_vals] * 4))}
# Stride must be greater than 1 to force non-integer error
startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
s_vals = [x for x in range(startStride, testGen.args.max_pooling_stride + 1)]
strides = {x for x in itertools.product(*([s_vals] * 2))}
k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
kernels = {x for x in itertools.product(*([k_vals] * 2))}
if testGen.args.oversize:
# add some oversize argument values
bigStride = 7
strides.update(
{x for x in itertools.product(*([[startStride, bigStride]] * 2))}
)
bigKernel = 9
kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
if max(shape) < 64:
# padding must be less than the kernel size
bigPadding = bigKernel - 1
paddings.update(
{x for x in itertools.product(*([[0, bigPadding]] * 4))}
)
# There are too many parameter combinations, so generate them sparsely;
# very sparsely for negative tests
sparsity_factor = 2 if error_name else 500
sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
n = 0
for s in sorted(list(strides)):
for p in sorted(list(paddings)):
for k in sorted(list(kernels)):
if error_name in [
ErrorIf.StrideSmallerOne,
ErrorIf.KernelSmallerOne,
ErrorIf.PadSmallerZero,
ErrorIf.PadLargerEqualKernel,
]:
sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
testGen, error_name, s, p, k
)
if None not in [sNew, pNew, kNew] and n % sparsity == 0:
arg_list.append(
(
"st{}_kern{}_pad{}".format(
"".join([str(x) for x in sNew]),
"".join([str(x) for x in kNew]),
"".join([str(x) for x in pNew]),
),
[sNew, pNew, kNew],
)
)
elif (
n % sparsity == 0
# padding must not exceed the kernel size
and p[0] < k[0]
and p[1] < k[0]
and p[2] < k[1]
and p[3] < k[1]
# the padded shape must exceed the kernel size
and (shape[1] + p[0] + p[1]) > k[0]
and (shape[2] + p[2] + p[3]) > k[1]
):
remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
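# Worked example (illustrative): shape H == 8, pad 0+0, kernel 3
# gives remainder_h == (8 - 3) % s; stride 5 yields 0 (exact OH == 2)
# while stride 2 yields 1 (non-integer output)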
if (
# the parameters must produce integer exact output
error_name != ErrorIf.PoolingOutputShapeNonInteger
and remainder_h == 0
and remainder_w == 0
) or (
error_name == ErrorIf.PoolingOutputShapeNonInteger
and (remainder_h != 0 or remainder_w != 0)
):
arg_list.append(
(
"st{}_kern{}_pad{}".format(
"".join([str(x) for x in s]),
"".join([str(x) for x in k]),
"".join([str(x) for x in p]),
),
[s, p, k],
)
)
n += 1
return arg_list
@staticmethod
def agCast(testGen, opName, shapeList, inDtype, error_name=None):
arg_list = []
# Enumerate the output types here
if error_name == ErrorIf.WrongOutputType:
dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
elif inDtype == DType.INT8:
dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
elif inDtype == DType.INT16:
dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
elif inDtype == DType.INT32:
dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
elif inDtype == DType.BOOL:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FLOAT:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output type for incorrect input type
dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
else:
raise Exception("Unexpected input dtype: {}".format(inDtype))
for dtype in dtypeList:
arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
return arg_list
@staticmethod
def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
arg_list = []
# Enumerate the output types here
for outDtype in [
DType.UINT8,
DType.INT8,
DType.INT16,
DType.INT32,
DType.UINT16,
]:
if (
outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
and error_name == ErrorIf.OutputZeroPointNotZero
):
continue
if (
outDtype != DType.UINT16
and error_name == ErrorIf.U16OutputZeroPointNotValid
) or (
inDtype != DType.UINT16
and error_name == ErrorIf.U16InputZeroPointNotValid
):
# ErrorIfs only valid with UINT16
continue
if (
inDtype == DType.UINT8
and outDtype not in [DType.INT8, DType.INT16]
and error_name != ErrorIf.WrongOutputType
):
# The only output dtypes for UINT8 are INT8/INT16, skip all others
continue
if (
inDtype not in [DType.INT8, DType.INT16]
and outDtype == DType.UINT8
and error_name != ErrorIf.WrongOutputType
):
# The only input dtypes for a UINT8 output are INT8/INT16, skip all others
continue
if (
inDtype == DType.UINT16
and outDtype != DType.INT16
and error_name != ErrorIf.WrongOutputType
):
# The only output dtype for UINT16 is INT16, skip all others
continue
if (
inDtype != DType.INT16
and outDtype == DType.UINT16
and error_name != ErrorIf.WrongOutputType
):
# The only input dtype for a UINT16 output is INT16, skip all others
continue
if (
error_name == ErrorIf.WrongOutputType
and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
):
continue
for scale32 in [False, True]:
if error_name == ErrorIf.ScaleTrue and not scale32:
continue
elif error_name == ErrorIf.ScaleNotTrue and scale32:
continue
for double_round in [False, True]:
if error_name == ErrorIf.ScaleNotTrue and not double_round:
continue
for per_channel in [False, True]:
if (
inDtype == DType.INT48
and scale32
and error_name != ErrorIf.ScaleTrue
):
# Illegal condition. Must be scale32=False
continue
if (
double_round
and not scale32
and error_name != ErrorIf.ScaleNotTrue
):
# Illegal condition. ERROR_IF(!scale32 && double_round)
continue
arg_list.append(
(
"out{}_sc{}_dr{}_pc{}".format(
DTypeNames[outDtype],
int(scale32),
int(double_round),
int(per_channel),
),
[outDtype, scale32, double_round, per_channel],
)
)
return arg_list
@staticmethod
def agMul(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
if dtype == DType.INT32:
for p in range(testGen.args.num_rand_permutations):
shift = testGen.randInt(0, 32)
arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
else:
arg_list.append(("perm0_shift0", [0]))
return arg_list
@staticmethod
def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
arg_list.append(("roundTrue", [True]))
arg_list.append(("roundFalse", [False]))
return arg_list
# Helper function for reshape. Gets some factors of a larger number.
@staticmethod
def getFactors(val, start=1):
factors = []
for i in range(start, int(np.sqrt(val)) + 1):
if (val % i) == 0:
factors.append(i)
return factors
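# Worked example (illustrative): getFactors(12) == [1, 2, 3]. Only
# factors up to sqrt(val) are returned; the co-factor is recovered in
# agReshape as the remaining elements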
@staticmethod
def agReshape(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
origShape = shapeList[0]
totalElements = 1
for s in origShape:
totalElements *= s
# This code is NOT fast. Fortunately, the numbers are fairly small.
factors = TosaArgGen.getFactors(totalElements)
for p in range(testGen.args.num_rand_permutations):
newRank = testGen.randInt(1, 7)
if len(factors) < newRank:
continue
found = True
# escape_counter breaks the while loop if it runs for too long
escape_counter = 0
while found:
newShape = []
# Generate newShape ensuring it isn't a duplicate
remainingElements = totalElements
shuffledFactors = testGen.rng.permutation(factors)
for i in range(1, newRank):
# pick rank-1 factors
newShape.append(shuffledFactors[0])
remainingElements = remainingElements // shuffledFactors[0]
shuffledFactors = testGen.rng.permutation(
TosaArgGen.getFactors(remainingElements)
)
newShape.append(remainingElements)
# Check for duplicates
found = False
for name, other_shape in arg_list:
if other_shape[0] == newShape:
found = True
break
escape_counter += 1
if escape_counter >= 100:
break
if not found:
arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))
return arg_list
@staticmethod
def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
ifm_shape = shapeList[0]
if error_name == ErrorIf.IndexOutsideBounds:
incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
incorrect_small_index = range(-len(ifm_shape), 0)
permutations = [p for p in itertools.permutations(incorrect_large_index)]
permutations.extend(
[p for p in itertools.permutations(incorrect_small_index)]
)
elif error_name == ErrorIf.IndexUsedTwice:
# Create list with a duplicated index
perm_range = list(range(len(ifm_shape)))
index_choice = testGen.rng.choice(range(len(perm_range)))
perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
permutations = [p for p in itertools.permutations(perm_range)]
else:
# Get all permutations
permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
# Limit to possible permutations from shape dimension or argument setting
limit = min(len(permutations), testGen.args.num_rand_permutations)
# Randomly shuffle all the permutations
random_permutations = testGen.rng.permutation(permutations)
# Create a list with the required number of permutations
arg_list = [
("perm{}".format(p), [random_permutations[p].tolist()])
for p in range(limit)
]
return arg_list
@staticmethod
def agSlice(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
ifm_shape = shapeList[0]
rank = len(ifm_shape)
for p in range(testGen.args.num_rand_permutations):
start = []
size = []
valid = True
for i in range(rank):
if ifm_shape[i] > 1:
start.append(testGen.randInt(0, ifm_shape[i]))
size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
# Invalid slice size?
if size[i] == 0:
valid = False
else:
start.append(0)
size.append(1)
if valid:
# If ERROR_IF test required then incorrect start, size will be returned
start, size = TosaErrorIfArgGen.eiSliceErrorIf(
testGen, error_name, ifm_shape, start, size
)
arg_list.append(("perm{}".format(p), [start, size]))
return arg_list
@staticmethod
def agTile(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
ifm_shape = shapeList[0]
rank = len(ifm_shape)
for p in range(testGen.args.num_rand_permutations):
# Pick a few random but small multiple values, because otherwise
# this has a tendency to generate enormous tensors
multiples = []
for i in range(rank):
if ifm_shape[i] > 1000:
# Use a multiple of 1 if the ifm_shape dimension is large, to
# reduce the tensor size
multiples.append(1)
elif max(ifm_shape) > 1000:
multiples.append(2)
else:
multiples.append(testGen.randInt(1, 4))
arg_list.append(("perm{}".format(p), [multiples]))
return arg_list
@staticmethod
def agResize(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
ifm_shape = shapeList[0]
def get_aspect_ratio_resize_params():
common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
aspect_ratio = testGen.rng.choice(common_aspect_ratios)
invert = testGen.rng.choice((False, True))
letterbox = testGen.rng.choice((False, True))
scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
scale_y_d = scale_x_d = 1
offset_x = offset_y = 0
if letterbox:
max_border = scale_y_n
border_y = testGen.randInt(low=0, high=max_border)
border_x = 0
else:
# Pillarboxing
border_y = 0
max_border = scale_x_n
border_x = testGen.randInt(low=0, high=max_border)
scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
offset = (offset_y, offset_x)
border = (border_y, border_x)
return scale, offset, border
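# Worked example (illustrative): aspect ratio (3, 2) without invert
# gives scale == (2, 1, 3, 1), i.e. the output is scaled ~3x in width
# and ~2x in height; letterboxing then picks border_y in [0, 2)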
def get_upscale_downscale_params():
valid_params = False
while not valid_params:
upscale = testGen.rng.choice((False, True))
# True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
origin_sampling = testGen.rng.choice((False, True))
if upscale:
shift = testGen.randInt(low=1, high=4)
scale_x_d = scale_y_d = 1
scale_x_n = scale_y_n = (
1 << shift if origin_sampling else 2 << shift
)
border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
else:
scale_x_n = 1
scale_y_n = 1
# Return list of valid scale_*_d values (max value 4) given input dim shape
def get_valid_denom(ifm_dim):
return [x for x in range(1, 5) if ifm_dim % x == 1]
# Generate list of valid downscale values and choose one randomly
valid_scale_y_ds = get_valid_denom(ifm_shape[1])
valid_scale_x_ds = get_valid_denom(ifm_shape[2])
if not valid_scale_y_ds and not valid_scale_x_ds:
# Bad parameters, skip
continue
if not valid_scale_y_ds:
scale_y_d = 1
else:
scale_y_d = testGen.rng.choice(valid_scale_y_ds)
if not valid_scale_x_ds:
scale_x_d = 1
else:
scale_x_d = testGen.rng.choice(valid_scale_x_ds)
border_x = border_y = 0
offset_y = testGen.randInt(0, 16 * scale_y_n)
offset_x = testGen.randInt(0, 16 * scale_x_n)
valid_params = True
scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
offset = (offset_y, offset_x)
border = (border_y, border_x)
return scale, offset, border
def get_rand_params():
# Scale
scale_y_n = testGen.randInt(low=1, high=(1 << 11))
scale_x_n = testGen.randInt(low=1, high=(1 << 11))
scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))
# Offsets and border within the scale
offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)
scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
offset = (offset_y, offset_x)
border = (border_y, border_x)
return scale, offset, border
for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
# Exclude illegal {mode, type} configurations. Pick legal output types
if mode == ResizeMode.NEAREST and dtype == DType.INT8:
outputDTypeList = [DType.INT8]
elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
outputDTypeList = [DType.INT16]
elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
outputDTypeList = [DType.INT32]
elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
outputDTypeList = [DType.INT48]
elif dtype == DType.FLOAT:
outputDTypeList = [DType.FLOAT]
elif error_name == ErrorIf.WrongInputType:
# If an incorrect input type is used then we set a 'correct'
# output type to avoid other errors
outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
else:
continue
arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
for outputDType in outputDTypeList:
perm = 0
while perm < testGen.args.num_rand_permutations:
# Random choice of type of params we are testing
_rnd_param_fn = testGen.rng.choice(
(
get_rand_params,
get_upscale_downscale_params,
get_aspect_ratio_resize_params,
)
)
scale, offset, border = _rnd_param_fn()
# Expand params for bounds-checking
(scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
(offset_y, offset_x) = offset
(border_y, border_x) = border
# Make sure output dimensions OH and OW are integers
partial_output_y = (
(ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
)
partial_output_x = (
(ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
)
if error_name == ErrorIf.ResizeOutputShapeNonInteger:
if (
partial_output_y % scale_y_d == 0
and partial_output_x % scale_x_d == 0
):
# Skip this test as it doesn't produce NonInteger output
perm += 1
continue
else:
while partial_output_y % scale_y_d != 0:
scale_y_d -= 1
while partial_output_x % scale_x_d != 0:
scale_x_d -= 1
output_y = partial_output_y // scale_y_d + 1
output_x = partial_output_x // scale_x_d + 1
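# Worked example (illustrative): ifm H == 5, scale 2/1, offset 0 and
# border 0 give partial_output_y == (5 - 1) * 2 == 8, so
# output_y == 8 // 1 + 1 == 9 (a 2x upscale)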
if (
output_y >= testGen.args.max_resize_output_dim
or output_x >= testGen.args.max_resize_output_dim
) and error_name is None:
# Skip positive test if output dim will be too high
# Avoid high test latency and OOM issues
perm += 1
continue
if (
output_y <= 0
or output_y >= MAX_RESIZE_DIMENSION
or output_x <= 0
or output_x >= MAX_RESIZE_DIMENSION
):
# Output dimensions out of scope
if error_name is not None and perm > 0:
# As long as we have one ERROR_IF test, don't worry
# about creating all the other permutations
perm += 1
continue
if error_name == ErrorIf.ResizeOutputShapeMismatch and (
(
output_y + scale_y_d >= MAX_RESIZE_DIMENSION
and output_y - scale_y_d < 1
)
or (
output_x + scale_x_d >= MAX_RESIZE_DIMENSION
and output_x - scale_x_d < 1
)
):
# Can't create a negative test with these params as it
# will create invalid output size
if perm > 0:
perm += 1
continue
scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
offset = [offset_y, offset_x]
border = [border_y, border_x]
# Common for all data types
if error_name is not None:
(
scale,
offset,
border,
outputDTypeNew,
) = TosaErrorIfArgGen.eiResizeErrorIf(
testGen,
error_name,
mode,
dtype,
shapeList,
outputDType,
scale,
offset,
border,
)
else:
outputDTypeNew = outputDType
arg_to_append = (
arg_str.format(
"N" if mode == ResizeMode.NEAREST else "B",
testGen.typeStr(outputDTypeNew),
scale[0],
scale[1],
scale[2],
scale[3],
offset[0],
offset[1],
border[0],
border[1],
),
[
mode,
scale,
offset,
border,
dtype,
outputDTypeNew,
],
)
if arg_to_append in arg_list:
# Skip already generated test params
continue
# Valid permutation
perm += 1
arg_list.append(arg_to_append)
return arg_list
@staticmethod
def agTable(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
if dtype == DType.INT8:
table = np.int32(
testGen.rng.integers(low=-128, high=128, size=[256])
).tolist()
else: # INT16
table = np.int32(
testGen.rng.integers(low=-32768, high=32768, size=[513])
).tolist()
# Make sure all slopes are within the REQUIRE min/max of a 16-bit int
for idx in range(len(table) - 1):
slope = table[idx + 1] - table[idx]
# Alter the next table entry to force the slope to be ok
if slope > 32767:
table[idx + 1] -= slope - 32767
if slope < -32768:
table[idx + 1] -= slope + 32768
slope = table[idx + 1] - table[idx]
assert slope <= 32767 and slope >= -32768
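# Worked example (illustrative): if table[idx] == -32768 and
# table[idx + 1] == 32767, the slope is 65535, so the next entry is
# pulled down by 32768 to -1, giving a legal slope of 32767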
arg_list.append(
(
"",
[table],
)
)
return arg_list
@staticmethod
def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
# CondIf generates the condition values here.
# Convert to tensors in the build function, along with the
# then and else blocks
arg_list = []
for c in [False, True]:
arg_list.append(("cond{}".format(int(c)), [c]))
return arg_list
@staticmethod
def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
# While loop: 0 iterations, 1, more than 1
arg_list = []
for iter in [0, 1, 4]:
arg_list.append(("iter{}".format(iter), [iter]))
return arg_list