blob: cfe7cc6fd723ac8b14d144d676f34217d9898a02 [file] [log] [blame]
Won Jeon64e4bfe2024-01-18 06:31:55 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
Jeremy Johnson1271c442023-09-05 11:39:26 +01005from enum import IntEnum
James Ward24dbc422022-10-19 12:20:31 +01006
7import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from tosa.DType import DType
9
# Maximum dimension size for output and inputs for RESIZE
MAX_RESIZE_DIMENSION = 16384

# Data type information dictionary
# - str: filename abbreviation
# - width: number of bits needed for the type
# - fullset: precalculated number of possible values in the data type's range,
#   equal to 2^width
# - json: JSON type string
DTYPE_ATTRIBUTES = {
    DType.BOOL: {"str": "b", "width": 1, "fullset": 2, "json": "BOOL"},
    DType.INT4: {"str": "i4", "width": 4, "fullset": 16, "json": "INT4"},
    DType.INT8: {"str": "i8", "width": 8, "fullset": 256, "json": "INT8"},
    DType.UINT8: {"str": "u8", "width": 8, "fullset": 256, "json": "UINT8"},
    DType.INT16: {"str": "i16", "width": 16, "fullset": 65536, "json": "INT16"},
    DType.UINT16: {"str": "u16", "width": 16, "fullset": 65536, "json": "UINT16"},
    DType.INT32: {"str": "i32", "width": 32, "fullset": 1 << 32, "json": "INT32"},
    DType.INT48: {"str": "i48", "width": 48, "fullset": 1 << 48, "json": "INT48"},
    DType.SHAPE: {"str": "s", "width": 64, "fullset": 1 << 64, "json": "SHAPE"},
    DType.FP16: {"str": "f16", "width": 16, "fullset": 65536, "json": "FP16"},
    DType.BF16: {"str": "bf16", "width": 16, "fullset": 65536, "json": "BF16"},
    DType.FP32: {"str": "f32", "width": 32, "fullset": 1 << 32, "json": "FP32"},
    DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "fullset": 256, "json": "FP8E4M3"},
    DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "fullset": 256, "json": "FP8E5M2"},
}
34
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010035
class ComplianceMode(IntEnum):
    """Enumeration of the compliance mode types.

    Each member carries an explicit integer value.
    NOTE(review): presumably these integers are consumed outside this module,
    so keep them stable when adding new modes.
    """

    EXACT = 0
    DOT_PRODUCT = 1
    ULP = 2
    FP_SPECIAL = 3
    REDUCE_PRODUCT = 4
    ABS_ERROR = 5
    RELATIVE = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010046
47
class DataGenType(IntEnum):
    """Enumeration of the data generator types.

    Each member carries an explicit integer value.
    NOTE(review): presumably these integers are consumed outside this module,
    so keep them stable when adding new generators.
    """

    PSEUDO_RANDOM = 0
    DOT_PRODUCT = 1
    BOUNDARY = 2
    FULL_RANGE = 3
    SPECIAL = 4
    FIXED_DATA = 5
Jeremy Johnson1271c442023-09-05 11:39:26 +010057
58
def dtypeWidth(dtype):
    """Return the width (in bits) recorded in DTYPE_ATTRIBUTES for ``dtype``.

    Raises:
        Exception: if ``dtype`` has no entry in DTYPE_ATTRIBUTES.
    """
    if dtype not in DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return DTYPE_ATTRIBUTES[dtype]["width"]
65
66
def dtypeIsFloat(dtype):
    """Return True when ``dtype`` is one of the floating point DTypes."""
    floatTypes = (
        DType.BF16,
        DType.FP16,
        DType.FP32,
        DType.FP8E4M3,
        DType.FP8E5M2,
    )
    return dtype in floatTypes
70
71
def dtypeIsSupportedByCompliance(dtype):
    """Return True if ``dtype`` is handled by the new data generation and compliance flow.

    ``dtype`` may be a single DType, or a list/tuple of DTypes in which case
    only its first entry is checked. Currently only FP32 and FP16 qualify.
    """
    leading = dtype[0] if isinstance(dtype, (list, tuple)) else dtype
    return leading in (DType.FP32, DType.FP16)
Jeremy Johnson1271c442023-09-05 11:39:26 +010077
78
def getOpNameFromOpListName(opName):
    """Strip any suffix from a TOSA_OP_LIST name to recover the base op name.

    Only the convolution ops listed below carry suffixes; any other name is
    returned unchanged.
    """
    baseNames = ("conv2d", "depthwise_conv2d", "transpose_conv2d", "conv3d")
    return next((base for base in baseNames if opName.startswith(base)), opName)
85
86
def valueToName(item, value):
    """Return the name of the attribute of ``item`` whose value equals ``value``.

    This convenience function is needed to print meaningful names for
    the values of the tosa.Op.Op and tosa.DType.DType classes, which are
    not Enum/IntEnum subclasses and so cannot report member names themselves.

    Args:
        item: The class, or object, to search for the value
        value: The value to find

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)  # returns 'INT8'

    Returns:
        The name of the first matching attribute, in dir() order
        (alphabetical).

    Raises:
        ValueError: if no attribute has the given value
    """
    found = next((name for name in dir(item) if getattr(item, name) == value), None)
    if found is None:
        raise ValueError(f"value ({value}) not found")
    return found
113 raise ValueError(f"value ({value}) not found")
114
115
116def allDTypes(*, excludes=None):
117 """Get a set of all DType values, optionally excluding some values.
118
119 This convenience function is needed to provide a sequence of DType values.
120 This would be much easier if DType was a subclass of Enum, or IntEnum,
121 as we could then iterate over the values directly, instead of using
122 dir() to find the attributes and then check if they are what we want.
123
124 Args:
125 excludes: iterable of DTYPE values (e.g. [DType.INT8, DType.BOOL])
126
127 Returns:
128 A set of DType values
129 """
130 excludes = () if not excludes else excludes
131 return {
132 getattr(DType, t)
133 for t in dir(DType)
134 if not callable(getattr(DType, t))
135 and not t.startswith("__")
136 and getattr(DType, t) not in excludes
137 }
138
139
140def usableDTypes(*, excludes=None):
141 """Get a set of usable DType values, optionally excluding some values.
142
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100143 Excludes uncommon types (DType.UNKNOWN, DType.UINT16, DType.UINT8) in
144 addition to the excludes specified by the caller, as the serializer lib
145 does not support them.
146 If you wish to include 'UNKNOWN', 'UINT8' or 'UINT16' use allDTypes
147 instead.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100148
149 Args:
150 excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])
151
152 Returns:
153 A set of DType values
154 """
Jeremy Johnson0633c3a2023-08-22 16:55:08 +0100155 omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16, DType.SHAPE}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100156 omit.update(excludes if excludes else ())
157 return allDTypes(excludes=omit)
158
159
160def product(shape):
161 value = 1
162 for n in shape:
163 value *= n
164 return value
James Ward8b390432022-08-12 20:48:56 +0100165
166
167def get_accum_dtype_from_tgTypes(dtypes):
168 # Get accumulate data-type from the test generator's defined types
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100169 assert isinstance(dtypes, list) or isinstance(dtypes, tuple)
170 return dtypes[-1]
James Ward8b390432022-08-12 20:48:56 +0100171
172
173def get_wrong_output_type(op_name, rng, input_dtype):
174 if op_name == "fully_connected" or op_name == "matmul":
175 if input_dtype == DType.INT8:
176 incorrect_types = (
177 DType.INT4,
178 DType.INT8,
179 DType.INT16,
180 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100181 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +0100182 DType.FP16,
183 )
184 elif input_dtype == DType.INT16:
185 incorrect_types = (
186 DType.INT4,
187 DType.INT8,
188 DType.INT16,
189 DType.INT32,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100190 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +0100191 DType.FP16,
192 )
James Ward24dbc422022-10-19 12:20:31 +0100193 elif (
194 input_dtype == DType.FP32
195 or input_dtype == DType.FP16
196 or input_dtype == DType.BF16
197 ):
James Ward8b390432022-08-12 20:48:56 +0100198 incorrect_types = (
199 DType.INT4,
200 DType.INT8,
201 DType.INT16,
202 DType.INT32,
203 DType.INT48,
204 )
Won Jeon2c34b462024-02-06 18:37:00 +0000205 elif input_dtype == DType.FP8E4M3 or input_dtype == DType.FP8E5M2:
206 incorrect_types = (
207 DType.INT4,
208 DType.INT8,
209 DType.INT16,
210 DType.INT32,
211 DType.INT48,
212 DType.FP32,
213 DType.BF16,
214 )
Jeremy Johnson05c711e2022-12-12 18:00:41 +0000215 else:
216 # Assume all types but the input type are incorrect
217 incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
James Ward8b390432022-08-12 20:48:56 +0100218 return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100219
220
Luke Huttona4e48ca2023-02-22 11:53:48 +0000221def get_rank_mismatch_shape(rng, output_shape):
222 """
223 Extends the rank of the provided output_shape by
224 an arbitrary amount but ensures the total element
225 count remains the same.
226 """
227 rank_modifier = rng.choice([1, 2, 3])
228 output_shape += [1] * rank_modifier
229 return output_shape
230
231
James Ward24dbc422022-10-19 12:20:31 +0100232def float32_is_valid_bfloat16(f):
233 """Return True if float value is valid bfloat16."""
234 f32_bits = get_float32_bitstring(f)
235 return f32_bits[16:] == "0" * 16
236
237
Won Jeon2c34b462024-02-06 18:37:00 +0000238def float32_is_valid_float8(f):
239 """Return True if float value is valid float8."""
240 f32_bits = get_float32_bitstring(f)
241 return f32_bits[8:] == "0" * 24
242
243
James Ward24dbc422022-10-19 12:20:31 +0100244def get_float32_bitstring(f):
245 """Return a big-endian string of bits representing a 32 bit float."""
246 f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0]
247 return f"{f32_bits_as_int:032b}"
248
249
250def float32_to_bfloat16(f):
251 """Turns fp32 value into bfloat16 by flooring.
252
253 Floors the least significant 16 bits of the input
254 fp32 value and returns this valid bfloat16 representation as fp32.
255 For simplicity during bit-wrangling, ignores underlying system
256 endianness and interprets as big-endian.
257 Returns a bf16-valid float following system's native byte order.
258 """
259 f32_bits = get_float32_bitstring(f)
260 f32_floored_bits = f32_bits[:16] + "0" * 16
261
262 # Assume sys.byteorder matches system's underlying float byteorder
263 fp_bytes = int(f32_floored_bits, 2).to_bytes(4, byteorder=sys.byteorder)
264 return struct.unpack("@f", fp_bytes)[0] # native byteorder
265
266
Won Jeon2c34b462024-02-06 18:37:00 +0000267def float32_to_fp8e4m3(f):
268 """Turns fp32 value into fp8e4m3"""
269 f32_bits = get_float32_bitstring(f)
270 fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] + "0" * 24
271 fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder)
272 return struct.unpack("@f", fp_bytes)[0] # native byteorder
273
274
275def float32_to_fp8e5m2(f):
276 """Turns fp32 value into fp8e5m2"""
277 f32_bits = get_float32_bitstring(f)
278 fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] + "0" * 24
279 fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder)
280 return struct.unpack("@f", fp_bytes)[0]
281
282
James Ward24dbc422022-10-19 12:20:31 +0100283vect_f32_to_bf16 = np.vectorize(
284 float32_to_bfloat16, otypes=(np.float32,)
285) # NumPy vectorize: applies function to vector faster than looping
Won Jeon2c34b462024-02-06 18:37:00 +0000286
287vect_f32_to_fp8e4m3 = np.vectorize(
288 float32_to_fp8e4m3, otypes=(np.float32,)
289) # NumPy vectorize: applies function to vector faster than looping
290
291vect_f32_to_fp8e5m2 = np.vectorize(
292 float32_to_fp8e5m2, otypes=(np.float32,)
293) # Numpy vectorize: applies function to vector faster than looping