blob: 4a4f6bb151f0dceb0e36ea2d4e61afa8eed86973 [file] [log] [blame]
Won Jeon64e4bfe2024-01-18 06:31:55 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
Jeremy Johnson1271c442023-09-05 11:39:26 +01005from enum import IntEnum
James Ward24dbc422022-10-19 12:20:31 +01006
7import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from tosa.DType import DType
9
# Upper bound on any dimension size of the RESIZE operator's input and output
# tensors (16 * 1024 == 16384).
MAX_RESIZE_DIMENSION = 16 * 1024
12
# Data type information dictionary, keyed by DType:
# - str: filename abbreviation
# - width: size of the type in bits (e.g. 8 for INT8, 1 for BOOL)
# - fullset: precalculated number of possible values in the data type's range,
#   equal to 2^width
# - json: JSON type string
DTYPE_ATTRIBUTES = {
    DType.BOOL: {"str": "b", "width": 1, "fullset": 2, "json": "BOOL"},
    DType.INT4: {"str": "i4", "width": 4, "fullset": 16, "json": "INT4"},
    DType.INT8: {"str": "i8", "width": 8, "fullset": 256, "json": "INT8"},
    DType.UINT8: {"str": "u8", "width": 8, "fullset": 256, "json": "UINT8"},
    DType.INT16: {"str": "i16", "width": 16, "fullset": 65536, "json": "INT16"},
    DType.UINT16: {"str": "u16", "width": 16, "fullset": 65536, "json": "UINT16"},
    DType.INT32: {"str": "i32", "width": 32, "fullset": 1 << 32, "json": "INT32"},
    DType.INT48: {"str": "i48", "width": 48, "fullset": 1 << 48, "json": "INT48"},
    DType.SHAPE: {"str": "s", "width": 64, "fullset": 1 << 64, "json": "SHAPE"},
    DType.FP16: {"str": "f16", "width": 16, "fullset": 65536, "json": "FP16"},
    DType.BF16: {"str": "bf16", "width": 16, "fullset": 65536, "json": "BF16"},
    DType.FP32: {"str": "f32", "width": 32, "fullset": 1 << 32, "json": "FP32"},
    DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "fullset": 256, "json": "FP8E4M3"},
    DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "fullset": 256, "json": "FP8E5M2"},
}
34
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010035
class ComplianceMode(IntEnum):
    """Enumerates the compliance-checking modes.

    Every member has an explicit integer value. NOTE(review): the numbering
    presumably has to match other consumers of these modes — confirm before
    renumbering or reordering.
    """

    EXACT = 0
    DOT_PRODUCT = 1
    ULP = 2
    FP_SPECIAL = 3
    REDUCE_PRODUCT = 4
    ABS_ERROR = 5
    RELATIVE = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010046
47
class DataGenType(IntEnum):
    """Enumerates the data-generator kinds.

    Every member has an explicit integer value. NOTE(review): the numbering
    presumably has to match other consumers of these generator kinds —
    confirm before renumbering or reordering.
    """

    PSEUDO_RANDOM = 0
    DOT_PRODUCT = 1
    BOUNDARY = 2
    FULL_RANGE = 3
    SPECIAL = 4
    FIXED_DATA = 5
Jeremy Johnson1271c442023-09-05 11:39:26 +010057
58
def dtypeWidth(dtype):
    """Return the width (in bits) recorded for ``dtype`` in DTYPE_ATTRIBUTES.

    Raises:
        Exception: if ``dtype`` has no entry in DTYPE_ATTRIBUTES.
    """
    attributes = DTYPE_ATTRIBUTES.get(dtype)
    if attributes is None:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return attributes["width"]
65
66
def dtypeIsFloat(dtype):
    """Return True if ``dtype`` is one of the floating-point DType values."""
    floating_types = (
        DType.BF16,
        DType.FP16,
        DType.FP32,
        DType.FP8E4M3,
        DType.FP8E5M2,
    )
    return dtype in floating_types
70
71
def dtypeIsSupportedByCompliance(dtype):
    """Return True if the new data generation and compliance flow supports ``dtype``.

    ``dtype`` may be a single DType or a list/tuple of them. For a sequence,
    only its first element is checked.
    """
    first = dtype[0] if isinstance(dtype, (list, tuple)) else dtype
    supported = (DType.FP32, DType.FP16)
    return first in supported
Jeremy Johnson1271c442023-09-05 11:39:26 +010077
78
def getOpNameFromOpListName(opName):
    """Strip any suffix from a TOSA_OP_LIST name to recover the base op name.

    Only the convolution families are recognised as prefixes. Any other name
    is returned unchanged.
    """
    base_names = ("conv2d", "depthwise_conv2d", "transpose_conv2d", "conv3d")
    return next(
        (base for base in base_names if opName.startswith(base)),
        opName,
    )
85
86
def valueToName(item, value):
    """Return the name of an attribute of ``item`` whose value equals ``value``.

    This helper exists to print readable names for values of the
    tosa.Op.Op and tosa.DType.DType classes. Those classes are not Enum or
    IntEnum subclasses, so they provide no built-in value-to-name lookup.

    Args:
        item: The class, or object, whose attributes are searched
        value: The value to look for

    Example:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'

    Returns:
        The first matching attribute name, in ``dir(item)`` order.

    Raises:
        ValueError: if no attribute has the given value
    """
    match = next(
        (attr_name for attr_name in dir(item) if getattr(item, attr_name) == value),
        None,
    )
    if match is None:
        raise ValueError(f"value ({value}) not found")
    return match
114
115
def allDTypes(*, excludes=None):
    """Return the set of all DType values, minus any listed in ``excludes``.

    DType is not an Enum, so its values cannot be iterated directly. Instead,
    this walks ``dir(DType)`` and keeps every attribute that is neither
    callable nor a dunder name.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    skip = excludes if excludes else ()
    found = set()
    for attr_name in dir(DType):
        member = getattr(DType, attr_name)
        if callable(member) or attr_name.startswith("__"):
            continue
        if member in skip:
            continue
        found.add(member)
    return found
138
139
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    Always omits DType.UNKNOWN, DType.UINT8, DType.UINT16 and DType.SHAPE,
    in addition to any values given in ``excludes``. The serializer lib is
    described as not supporting the UNKNOWN/UINT8/UINT16 types. If you need
    any of the omitted types, use allDTypes instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16, DType.SHAPE}
    omit.update(excludes or ())
    return allDTypes(excludes=omit)
158
159
def product(shape):
    """Return the product of all entries of ``shape`` (1 for an empty shape)."""
    total = 1
    for dim in shape:
        total *= dim
    return total
James Ward8b390432022-08-12 20:48:56 +0100165
166
def get_accum_dtypes_from_tgTypes(dtypes):
    """Return the accumulator data types implied by a test-generator type list.

    Only the first entry (input type) and the last entry (output type) of
    ``dtypes`` are read:
    - FP16 input with FP16 output gives [FP16, FP32].
    - BF16 output gives [FP32].
    - Anything else gives [output type].

    Args:
        dtypes: list or tuple of DType values.

    Returns:
        A list of DType values.
    """
    assert isinstance(dtypes, (list, tuple))
    in_type, out_type = dtypes[0], dtypes[-1]
    if in_type == DType.FP16 and out_type == DType.FP16:
        return [DType.FP16, DType.FP32]
    if out_type == DType.BF16:
        return [DType.FP32]
    return [out_type]
James Ward8b390432022-08-12 20:48:56 +0100179
180
def get_wrong_output_type(op_name, rng, input_dtype):
    """Randomly pick an output data type that is invalid for ``op_name``.

    Only "fully_connected" and "matmul" are supported. For each input data
    type handled below, a fixed tuple of invalid output types is used. For
    any other input data type, every usable type except the input type is
    treated as invalid.

    Args:
        op_name: Operator name, "fully_connected" or "matmul".
        rng: Random generator providing a NumPy-style ``choice(a=...)``.
        input_dtype: DType of the operator's input.

    Returns:
        The value returned by ``rng.choice`` over the invalid types.

    Raises:
        ValueError: if ``op_name`` is not a supported operator.
    """
    if op_name not in ("fully_connected", "matmul"):
        # No table of invalid types exists for other ops, so fail with a
        # clear message. Without this check the function would hit an
        # unbound local variable.
        raise ValueError(f"Unsupported op for wrong output type: {op_name}")

    if input_dtype == DType.INT8:
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT48,
            DType.FP32,
            DType.FP16,
        )
    elif input_dtype == DType.INT16:
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.FP32,
            DType.FP16,
        )
    elif input_dtype in (DType.FP32, DType.FP16, DType.BF16):
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
        )
    elif input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.BF16,
        )
    else:
        # Assume all types but the input type are incorrect
        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
    return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100227
228
def get_rank_mismatch_shape(rng, output_shape):
    """Return ``output_shape`` with 1, 2 or 3 trailing unit dimensions added.

    Appending size-1 dimensions raises the rank but leaves the total element
    count unchanged. The caller's ``output_shape`` is not modified; a new
    list is returned. Tuples are accepted as well as lists.

    Args:
        rng: Random generator providing a NumPy-style ``choice``.
        output_shape: Sequence of dimension sizes.

    Returns:
        A new list: the original dimensions followed by the extra 1s.
    """
    rank_modifier = rng.choice([1, 2, 3])
    # Build a new list instead of using `+=`, which mutated the caller's list
    # and raised TypeError for tuple input.
    return list(output_shape) + [1] * rank_modifier
238
239
def float32_is_valid_bfloat16(f):
    """Return True if ``f`` is exactly representable as a bfloat16.

    That holds when the 16 least significant bits of its float32 encoding
    are all zero.
    """
    bits = get_float32_bitstring(f)
    return bits.endswith("0" * 16)
244
245
def float32_is_valid_float8(f):
    """Return True if the low 24 bits of f's float32 encoding are all zero.

    Only the top byte of the 32-bit pattern may be set. That is the form
    produced by the float32_to_fp8e4m3 / float32_to_fp8e5m2 packers.
    """
    bits = get_float32_bitstring(f)
    low_24_bits = bits[8:]
    return low_24_bits == "0" * 24
250
251
def get_float32_bitstring(f):
    """Return the 32-character big-endian bit string of ``f`` packed as float32.

    The sign bit comes first and the lowest mantissa bit last.
    """
    (word,) = struct.unpack(">L", struct.pack(">f", f))
    return format(word, "032b")
256
257
def float32_to_bfloat16(f):
    """Reduce a value to bfloat16 precision by truncation.

    The value is encoded as a 32-bit float, and the 16 least significant
    bits (the low part of the mantissa) are cleared. Clearing bits shrinks
    the magnitude, so the result is rounded toward zero. That is a floor only
    for non-negative inputs. The return value is a Python float that is
    exactly representable in bfloat16.
    """
    high_half = get_float32_bitstring(f)[:16]
    word = int(high_half + "0" * 16, 2)
    # Pack and unpack in an explicit big-endian order, so the result does
    # not depend on the host's byte order.
    return struct.unpack(">f", word.to_bytes(4, byteorder="big"))[0]
273
274
def float32_to_fp8e4m3(f):
    """Pack fp8e4m3-style bit fields of ``f`` into the top byte of a float32.

    An 8-bit field is assembled from the float32 encoding of ``f``:
    - the sign bit,
    - the 4 most significant exponent bits,
    - the 3 most significant mantissa bits.

    The field is placed in the most significant byte of a 32-bit pattern
    whose low 24 bits are zero, and that pattern is returned reinterpreted
    as a float.

    NOTE(review): the 8-bit float32 exponent is truncated to its top 4 bits
    rather than re-biased, so this is a bit-field extraction, not a numeric
    rounding of ``f`` to fp8e4m3. Presumably consumers read the top byte as
    the fp8 encoding — confirm before changing.
    """
    bits = get_float32_bitstring(f)
    sign, exponent, mantissa = bits[0], bits[1:5], bits[9:12]
    word = int(sign + exponent + mantissa, 2) << 24
    return struct.unpack(">f", word.to_bytes(4, byteorder="big"))[0]
281
282
def float32_to_fp8e5m2(f):
    """Pack fp8e5m2-style bit fields of ``f`` into the top byte of a float32.

    An 8-bit field is assembled from the float32 encoding of ``f``:
    - the sign bit,
    - the 5 most significant exponent bits,
    - the 2 most significant mantissa bits.

    The field is placed in the most significant byte of a 32-bit pattern
    whose low 24 bits are zero, and that pattern is returned reinterpreted
    as a float.

    NOTE(review): the 8-bit float32 exponent is truncated to its top 5 bits
    rather than re-biased, so this is a bit-field extraction, not a numeric
    rounding of ``f`` to fp8e5m2. Presumably consumers read the top byte as
    the fp8 encoding — confirm before changing.
    """
    bits = get_float32_bitstring(f)
    sign, exponent, mantissa = bits[0], bits[1:6], bits[9:11]
    word = int(sign + exponent + mantissa, 2) << 24
    return struct.unpack(">f", word.to_bytes(4, byteorder="big"))[0]
289
290
# Element-wise array versions of the scalar converters defined above.
# np.vectorize calls the Python function once per element: it is a
# convenience wrapper, not a speed-up over a Python loop. otypes fixes the
# result dtype to float32.
vect_f32_to_bf16 = np.vectorize(float32_to_bfloat16, otypes=(np.float32,))

vect_f32_to_fp8e4m3 = np.vectorize(float32_to_fp8e4m3, otypes=(np.float32,))

vect_f32_to_fp8e5m2 = np.vectorize(float32_to_fp8e5m2, otypes=(np.float32,))