# blob: d4f3898328b672022cc2426d3f7da29671db67a3 [file] [log] [blame]
# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import struct
import sys
from enum import IntEnum
import numpy as np
from tosa.DType import DType
# Upper bound on the size of any single input or output dimension for RESIZE
MAX_RESIZE_DIMENSION = 16384
# Highest tensor rank supported by the test generator.
MAX_TENSOR_RANK = 6
# Data type information dictionary, keyed by DType value
# - str: filename abbreviation
# - width: number of bits needed for type (e.g. 1 for BOOL, 8 for INT8)
# - fullset: precalculated number of possible values in the data type's range, equal to 2^width
# - json: JSON type string
DTYPE_ATTRIBUTES = {
    DType.BOOL: {"str": "b", "width": 1, "fullset": 2, "json": "BOOL"},
    DType.INT4: {"str": "i4", "width": 4, "fullset": 16, "json": "INT4"},
    DType.INT8: {"str": "i8", "width": 8, "fullset": 256, "json": "INT8"},
    DType.UINT8: {"str": "u8", "width": 8, "fullset": 256, "json": "UINT8"},
    DType.INT16: {"str": "i16", "width": 16, "fullset": 65536, "json": "INT16"},
    DType.UINT16: {"str": "u16", "width": 16, "fullset": 65536, "json": "UINT16"},
    DType.INT32: {"str": "i32", "width": 32, "fullset": 1 << 32, "json": "INT32"},
    DType.INT48: {"str": "i48", "width": 48, "fullset": 1 << 48, "json": "INT48"},
    DType.SHAPE: {"str": "s", "width": 64, "fullset": 1 << 64, "json": "SHAPE"},
    DType.FP16: {"str": "f16", "width": 16, "fullset": 65536, "json": "FP16"},
    DType.BF16: {"str": "bf16", "width": 16, "fullset": 65536, "json": "BF16"},
    DType.FP32: {"str": "f32", "width": 32, "fullset": 1 << 32, "json": "FP32"},
    DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "fullset": 256, "json": "FP8E4M3"},
    DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "fullset": 256, "json": "FP8E5M2"},
}
class ComplianceMode(IntEnum):
    """Compliance mode types.

    Integer-valued enumeration with explicitly assigned values 0 to 6.
    """

    # NOTE(review): the numeric values are presumably shared with code
    # outside this file - confirm before renumbering or reordering.
    EXACT = 0
    DOT_PRODUCT = 1
    ULP = 2
    FP_SPECIAL = 3
    REDUCE_PRODUCT = 4
    ABS_ERROR = 5
    RELATIVE = 6
class DataGenType(IntEnum):
    """Data generator types.

    Integer-valued enumeration with explicitly assigned values 0 to 5.
    """

    # NOTE(review): the numeric values are presumably shared with code
    # outside this file - confirm before renumbering or reordering.
    PSEUDO_RANDOM = 0
    DOT_PRODUCT = 1
    BOUNDARY = 2
    FULL_RANGE = 3
    FP_SPECIAL = 4
    FIXED_DATA = 5
def dtypeWidth(dtype):
    """Look up the width recorded for a data type.

    Args:
        dtype: DType value used as a key into DTYPE_ATTRIBUTES

    Returns:
        The "width" entry of the matching DTYPE_ATTRIBUTES record

    Raises:
        Exception if dtype has no entry in DTYPE_ATTRIBUTES
    """
    attributes = DTYPE_ATTRIBUTES.get(dtype)
    if attributes is None:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return attributes["width"]
def dtypeIsFloat(dtype):
    """Report whether dtype is one of the floating point data types.

    Returns:
        True for BF16, FP16, FP32, FP8E4M3 and FP8E5M2, otherwise False
    """
    float_dtypes = (
        DType.BF16,
        DType.FP16,
        DType.FP32,
        DType.FP8E4M3,
        DType.FP8E5M2,
    )
    return dtype in float_dtypes
def dtypeIsSupportedByCompliance(dtype):
    """Check for types supported by the new data generation and compliance flow.

    Args:
        dtype: a single DType value, or a list/tuple of them in which case
            only the first entry is examined

    Returns:
        True when the examined type is FP32 or FP16
    """
    # A sequence of types is judged by its first element only
    candidate = dtype[0] if isinstance(dtype, (list, tuple)) else dtype
    return candidate in (DType.FP32, DType.FP16)
def getOpNameFromOpListName(opName):
    """Strip any suffix from a TOSA_OP_LIST name of a convolution operator.

    Args:
        opName: operator list name, possibly carrying a suffix

    Returns:
        The first of "conv2d", "depthwise_conv2d", "transpose_conv2d" or
        "conv3d" that opName starts with, otherwise opName unchanged
    """
    base_ops = ("conv2d", "depthwise_conv2d", "transpose_conv2d", "conv3d")
    return next((base for base in base_ops if opName.startswith(base)), opName)
def valueToName(item, value):
    """Find the attribute name that holds the given value.

    Helper for printing readable names for values of classes such as
    tosa.Op.Op and tosa.DType.DType, which are plain classes rather than
    Enum/IntEnum subclasses and so offer no reverse lookup of their own.

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Example, to get the name of a DType value:
        name = valueToName(DType, DType.INT8)   # returns 'INT8'
        name = valueToName(DType, 4)            # returns 'INT8'

    Returns:
        The name of the first attribute, in dir() order, whose value
        compares equal to the requested value

    Raises:
        ValueError if the value is not found
    """
    not_found = object()
    matches = (name for name in dir(item) if getattr(item, name) == value)
    first_match = next(matches, not_found)
    if first_match is not_found:
        raise ValueError(f"value ({value}) not found")
    return first_match
def allDTypes(*, excludes=None):
    """Collect every DType value, optionally leaving some out.

    DType is not an Enum/IntEnum subclass, so its values cannot be iterated
    directly; instead the attributes reported by dir() are filtered down to
    the non-callable, non-dunder ones.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    skipped = excludes if excludes else ()
    dtypes = set()
    for attr_name in dir(DType):
        attr_value = getattr(DType, attr_name)
        # Ignore methods and Python's own double-underscore attributes
        if callable(attr_value) or attr_name.startswith("__"):
            continue
        if attr_value not in skipped:
            dtypes.add(attr_value)
    return dtypes
def usableDTypes(*, excludes=None):
    """Collect the usable DType values, optionally leaving some out.

    DType.UNKNOWN, DType.UINT8, DType.UINT16 and DType.SHAPE are always
    omitted, in addition to the excludes specified by the caller. The
    uncommon types (UNKNOWN, UINT8, UINT16) are left out as the serializer
    lib does not support them.

    If you wish to include any of the always-omitted types use allDTypes
    instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omitted = {DType.UNKNOWN, DType.UINT8, DType.UINT16, DType.SHAPE}
    if excludes:
        omitted |= set(excludes)
    return allDTypes(excludes=omitted)
def product(shape):
    """Multiply together all entries of shape.

    Args:
        shape: iterable of dimension sizes

    Returns:
        The product of the entries, or 1 when shape is empty
    """
    total = 1
    for dim in shape:
        total *= dim
    return total
def get_accum_dtypes_from_tgTypes(dtypes):
    """Get accumulate data-types from the test generator's defined types.

    Args:
        dtypes: list or tuple of DType values; the first entry is taken as
            the input type and the last entry as the output type

    Returns:
        A list of accumulator DType values:
        [FP16, FP32] when both input and output are FP16,
        [FP32] when the output is BF16,
        otherwise a list holding only the output type
    """
    assert isinstance(dtypes, (list, tuple))
    first_dtype, last_dtype = dtypes[0], dtypes[-1]

    if first_dtype == DType.FP16 and last_dtype == DType.FP16:
        return [DType.FP16, DType.FP32]
    if last_dtype == DType.BF16:
        return [DType.FP32]
    # by default, only the output type is used for accumulation
    return [last_dtype]
def get_wrong_output_type(op_name, rng, input_dtype):
    """Pick at random an output DType treated as incorrect for an operator.

    For "fully_connected" and "matmul" a hand-written list of incorrect
    output types is used for the covered input types. In every other case
    (any other operator, or an input type without a dedicated list) all
    usable types except the input type are assumed to be incorrect.

    Args:
        op_name: operator name
        rng: random generator object providing choice(a=...)
            (presumably a numpy Generator - confirm against callers)
        input_dtype: DType value of the operator input

    Returns:
        One entry chosen by rng from the incorrect types
    """
    # Stays None unless a dedicated list below applies, so that every
    # uncovered combination reaches the generic fallback instead of leaving
    # the variable unbound.
    incorrect_types = None

    if op_name == "fully_connected" or op_name == "matmul":
        if input_dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
            )
        elif (
            input_dtype == DType.FP32
            or input_dtype == DType.FP16
            or input_dtype == DType.BF16
        ):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        elif input_dtype == DType.FP8E4M3 or input_dtype == DType.FP8E5M2:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
            )

    if incorrect_types is None:
        # Assume all types but the input type are incorrect
        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))

    return rng.choice(a=incorrect_types)
def get_rank_mismatch_shape(rng, output_shape):
    """Return a copy of output_shape with a higher rank.

    Between one and three dimensions of size 1 are appended, so the rank
    changes but the total element count remains the same. The caller's
    sequence is left untouched; a new list is always returned.

    Args:
        rng: random generator object providing choice(sequence)
        output_shape: sequence of dimension sizes (list or tuple)

    Returns:
        A new list holding the original dimensions followed by the
        appended dimensions of size 1
    """
    rank_modifier = rng.choice([1, 2, 3])
    # Build a fresh list rather than using += which would modify the
    # caller's list in place (and fail outright for a tuple)
    return list(output_shape) + [1] * rank_modifier
def float32_is_valid_bfloat16(f):
    """Return True if float value is valid bfloat16.

    The value qualifies when the least significant 16 bits of its 32 bit
    float encoding are all zero.
    """
    low_half = get_float32_bitstring(f)[16:]
    return "1" not in low_half
def float32_is_valid_float8(f):
    """Return True if float value is valid float8.

    The value qualifies when the least significant 24 bits of its 32 bit
    float encoding are all zero.
    """
    # NOTE(review): this only inspects the low 24 bits of the fp32 encoding;
    # how that relates to the FP8E4M3 / FP8E5M2 formats is not visible
    # here - confirm.
    low_bits = get_float32_bitstring(f)[8:]
    return "1" not in low_bits
def get_float32_bitstring(f):
    """Return a big-endian string of bits representing a 32 bit float.

    Args:
        f: float value to encode as a 32 bit float

    Returns:
        A 32 character string of "0"/"1", most significant bit first
    """
    packed = struct.pack(">f", f)
    as_int = int.from_bytes(packed, byteorder="big")
    return format(as_int, "032b")
def float32_to_bfloat16(f):
    """Turns fp32 value into bfloat16 by flooring.

    Clears the least significant 16 bits of the input fp32 encoding
    (no rounding is applied) and returns the resulting bfloat16-valid
    value as fp32.

    The bit manipulation works on the big-endian encoding regardless of
    the system's endianness; the result is then rebuilt using the
    system's native byte order.

    Returns a bf16-valid float following system's native byte order.
    """
    # Keep the top 16 bits, zero the rest
    kept_bits = int(get_float32_bitstring(f), 2) & 0xFFFF0000
    # Assume sys.byteorder matches system's underlying float byteorder
    raw = kept_bits.to_bytes(4, byteorder=sys.byteorder)
    (value,) = struct.unpack("@f", raw)  # native byteorder
    return value
def float32_to_fp8e4m3(f):
    """Turns fp32 value into fp8e4m3 (placeholder).

    No fp8 conversion is performed yet: the value is encoded as a 32 bit
    float and decoded again, so the result is the input at fp32 precision.
    """
    # TODO: needs src/generate and src/verify code ready
    encoded = int(get_float32_bitstring(f), 2)
    raw = encoded.to_bytes(4, byteorder=sys.byteorder)
    (value,) = struct.unpack("@f", raw)  # native byteorder
    return value
def float32_to_fp8e5m2(f):
    """Turns fp32 value into fp8e5m2 (placeholder).

    No fp8 conversion is performed yet: the value is encoded as a 32 bit
    float and decoded again, so the result is the input at fp32 precision.
    """
    # TODO: needs src/generate and src/verify code ready
    encoded = int(get_float32_bitstring(f), 2)
    raw = encoded.to_bytes(4, byteorder=sys.byteorder)
    (value,) = struct.unpack("@f", raw)  # native byteorder
    return value
# Element-wise version of float32_to_bfloat16 producing float32 output.
# NumPy vectorize is a convenience wrapper (essentially a loop over the
# elements), not a performance optimisation.
vect_f32_to_bf16 = np.vectorize(
    float32_to_bfloat16, otypes=(np.float32,)
)
# Element-wise version of float32_to_fp8e4m3 producing float32 output.
# NumPy vectorize is a convenience wrapper (essentially a loop over the
# elements), not a performance optimisation.
vect_f32_to_fp8e4m3 = np.vectorize(
    float32_to_fp8e4m3, otypes=(np.float32,)
)
# Element-wise version of float32_to_fp8e5m2 producing float32 output.
# NumPy vectorize is a convenience wrapper (essentially a loop over the
# elements), not a performance optimisation.
vect_f32_to_fp8e5m2 = np.vectorize(
    float32_to_fp8e5m2, otypes=(np.float32,)
)