blob: 31a0ff0a0ee949a101594a341e60cae859e69275 [file] [log] [blame]
Won Jeon64e4bfe2024-01-18 06:31:55 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
Jeremy Johnson1271c442023-09-05 11:39:26 +01005from enum import IntEnum
James Ward24dbc422022-10-19 12:20:31 +01006
7import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from tosa.DType import DType
9
# Upper bound on the size of any single input or output dimension for RESIZE
MAX_RESIZE_DIMENSION = 2**14  # 16384
12
Jeremy Johnson1271c442023-09-05 11:39:26 +010013# Data type information dictionary
14# - str: filename abbreviation
# - width: number of bits needed for type
16# - json: JSON type string
# Each row below is (DType value, filename abbreviation, width, JSON type
# string); the rows are expanded into one attribute dictionary per data type.
# The width entries track the size in the type name (e.g. 4 for INT4).
DTYPE_ATTRIBUTES = {
    dtype: {"str": abbrev, "width": width, "json": json_name}
    for dtype, abbrev, width, json_name in (
        (DType.BOOL, "b", 1, "BOOL"),
        (DType.INT4, "i4", 4, "INT4"),
        (DType.INT8, "i8", 8, "INT8"),
        (DType.UINT8, "u8", 8, "UINT8"),
        (DType.INT16, "i16", 16, "INT16"),
        (DType.UINT16, "u16", 16, "UINT16"),
        (DType.INT32, "i32", 32, "INT32"),
        (DType.INT48, "i48", 48, "INT48"),
        (DType.SHAPE, "s", 64, "SHAPE"),
        (DType.FP16, "f16", 16, "FP16"),
        (DType.BF16, "bf16", 16, "BF16"),
        (DType.FP32, "f32", 32, "FP32"),
        (DType.FP8E4M3, "f8e4m3", 8, "FP8E4M3"),
        (DType.FP8E5M2, "f8e5m2", 8, "FP8E5M2"),
    )
}
33
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010034
class ComplianceMode(IntEnum):
    """Integer-valued enumeration of the compliance mode types.

    Members are numbered consecutively from 0 and, being IntEnum members,
    compare equal to the corresponding plain integers.
    """

    EXACT = 0
    DOT_PRODUCT = 1
    ULP = 2
    FP_SPECIAL = 3
    REDUCE_PRODUCT = 4
    ABS_ERROR = 5
    RELATIVE = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010045
46
class DataGenType(IntEnum):
    """Integer-valued enumeration of the data generator types.

    Members are numbered consecutively from 0 and, being IntEnum members,
    compare equal to the corresponding plain integers.
    """

    PSEUDO_RANDOM = 0
    DOT_PRODUCT = 1
    OP_BOUNDARY = 2
    OP_FULLSET = 3
    OP_SPECIAL = 4
    FIXED_DATA = 5
Jeremy Johnson1271c442023-09-05 11:39:26 +010056
57
def dtypeIsSupportedByCompliance(dtype):
    """Report whether a type is handled by the new data gen and compliance flow.

    Args:
        dtype: a single DType value, or a list/tuple of DType values, in
            which case only the first entry is examined

    Returns:
        True when the (first) type is DType.FP32 or DType.FP16
    """
    first_dtype = dtype[0] if isinstance(dtype, (list, tuple)) else dtype
    return first_dtype in (DType.FP32, DType.FP16)
Jeremy Johnson1271c442023-09-05 11:39:26 +010063
64
def getOpNameFromOpListName(opName):
    """Strip any suffix from a TOSA_OP_LIST name to recover the op name.

    Only the convolution ops are recognised as carrying a suffix; any other
    name is returned unchanged.

    Args:
        opName: name as used in TOSA_OP_LIST

    Returns:
        The matching convolution op name when opName starts with one,
        otherwise opName itself
    """
    conv_op_names = ("conv2d", "depthwise_conv2d", "transpose_conv2d", "conv3d")
    return next(
        (base for base in conv_op_names if opName.startswith(base)),
        opName,
    )
71
72
def valueToName(item, value):
    """Look up the attribute name that holds a given value.

    Classes such as tosa.Op.Op and tosa.DType.DType are plain classes rather
    than Enum/IntEnum subclasses, so there is no built-in way to turn one of
    their values back into a printable name; this helper fills that gap.

    Args:
        item: The class, or object, whose attributes are searched
        value: The value to look for

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)   # returns 'INT8'

    Returns:
        The name of the first attribute, in dir() order, whose value
        compares equal to the requested value

    Raises:
        ValueError if no attribute has the requested value
    """
    not_found = object()
    matching_names = (
        candidate for candidate in dir(item) if getattr(item, candidate) == value
    )
    name = next(matching_names, not_found)
    if name is not_found:
        raise ValueError(f"value ({value}) not found")
    return name
100
101
def allDTypes(*, excludes=None):
    """Get a set of all DType values, optionally excluding some values.

    This convenience function is needed to provide a sequence of DType values.
    This would be much easier if DType was a subclass of Enum, or IntEnum,
    as we could then iterate over the values directly, instead of using
    dir() to find the attributes and then check if they are what we want.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL]);
            one-shot iterators such as generators are accepted

    Returns:
        A set of DType values
    """
    # Materialise the excludes once: a one-shot iterator would otherwise be
    # drained by the first membership test and stop excluding later values.
    excluded = tuple(excludes) if excludes else ()
    values = set()
    for name in dir(DType):
        if name.startswith("__"):
            continue
        value = getattr(DType, name)
        if callable(value) or value in excluded:
            continue
        values.add(value)
    return values
124
125
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    DType.UNKNOWN, DType.UINT8, DType.UINT16 and DType.SHAPE are always left
    out, together with anything the caller lists in excludes.
    Use allDTypes instead if any of those types are wanted.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    always_omitted = {DType.UNKNOWN, DType.UINT8, DType.UINT16, DType.SHAPE}
    caller_omitted = excludes if excludes else ()
    return allDTypes(excludes=always_omitted.union(caller_omitted))
144
145
def product(shape):
    """Multiply together every entry of shape.

    Args:
        shape: iterable of numbers (e.g. the dimensions of a tensor)

    Returns:
        The product of all entries, or 1 when shape is empty
    """
    result = 1
    for dim in shape:
        result *= dim
    return result
James Ward8b390432022-08-12 20:48:56 +0100151
152
def get_accum_dtype_from_tgTypes(dtypes):
    """Get the accumulate data-type from the test generator's defined types.

    Args:
        dtypes: list or tuple of types; the accumulate type is taken to be
            the final entry

    Returns:
        The last element of dtypes
    """
    assert isinstance(dtypes, (list, tuple))
    accum_dtype = dtypes[-1]
    return accum_dtype
James Ward8b390432022-08-12 20:48:56 +0100157
158
def get_wrong_output_type(op_name, rng, input_dtype):
    """Pick an output data type that is incorrect for the given op and input.

    For "fully_connected" and "matmul" a hand-written list of incorrect
    output types is used for the known input types. Every other combination
    (other ops, or an input type without a dedicated list) falls back to
    treating all usable types except the input type as incorrect, so a
    candidate list is always available.

    Args:
        op_name: name of the operator being tested
        rng: random generator providing choice(a=...), used to pick the result
        input_dtype: DType value of the operator's input

    Returns:
        One entry chosen by rng from the incorrect types
    """
    incorrect_types = None

    if op_name in ("fully_connected", "matmul"):
        if input_dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        elif input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
            )

    if incorrect_types is None:
        # Assume all types but the input type are incorrect. This also covers
        # any op/type pairing without a dedicated list above, which would
        # otherwise leave incorrect_types unbound.
        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))

    return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100205
206
def get_rank_mismatch_shape(rng, output_shape):
    """Return a copy of output_shape with its rank extended by 1 to 3.

    The extra dimensions all have size 1, so the total element count is the
    same as for the provided shape. The caller's sequence is left untouched;
    a new list is returned.

    Args:
        rng: random generator providing choice(), used to pick how many
            dimensions to add
        output_shape: sequence of dimension sizes (list or tuple)

    Returns:
        A new list holding output_shape followed by the added 1s
    """
    rank_modifier = rng.choice([1, 2, 3])
    # Build a fresh list rather than using += so the argument is not mutated
    # and non-list sequences (e.g. tuples) are accepted.
    return list(output_shape) + [1] * rank_modifier
216
217
def float32_is_valid_bfloat16(f):
    """Return True if float value is valid bfloat16.

    The value is packed as a 32 bit float and is considered valid when the
    least significant 16 bits of that pattern are all zero.
    """
    (pattern,) = struct.unpack(">L", struct.pack(">f", f))
    return (pattern & 0x0000FFFF) == 0
222
223
def float32_is_valid_float8(f):
    """Return True if float value is valid float8.

    The value is packed as a 32 bit float and is considered valid when the
    least significant 24 bits of that pattern are all zero.
    """
    (pattern,) = struct.unpack(">L", struct.pack(">f", f))
    return (pattern & 0x00FFFFFF) == 0
228
229
def get_float32_bitstring(f):
    """Return a big-endian string of bits representing a 32 bit float.

    The result is always 32 characters long: sign bit first, then the
    exponent bits, then the mantissa bits.
    """
    packed = struct.pack(">f", f)
    return format(int.from_bytes(packed, byteorder="big"), "032b")
234
235
def float32_to_bfloat16(f):
    """Turn an fp32 value into a bfloat16-valid fp32 value by truncation.

    The least significant 16 bits of the fp32 pattern are cleared (no
    rounding is applied), and the resulting pattern is returned as a float.
    For simplicity the bit manipulation is done on a big-endian view of the
    value, independent of the system's endianness; the final conversion back
    to a float follows the system's native byte order.
    """
    (pattern,) = struct.unpack(">L", struct.pack(">f", f))
    truncated = pattern & 0xFFFF0000

    # Assume sys.byteorder matches system's underlying float byteorder
    raw = truncated.to_bytes(4, byteorder=sys.byteorder)
    (result,) = struct.unpack("@f", raw)  # native byteorder
    return result
251
252
def float32_to_fp8e4m3(f):
    """Turns fp32 value into fp8e4m3

    Keeps the fp32 sign bit, the 4 most significant exponent bits and the
    3 most significant mantissa bits, packs those 8 bits into the top byte
    of a 32 bit pattern (remaining 24 bits zero) and returns that pattern
    interpreted as a float in the system's native byte order.

    NOTE(review): the 8 selected bits are reinterpreted as the top byte of an
    fp32 pattern, so the returned number is generally not numerically close
    to the input - confirm this is what consumers expect.
    """
    (pattern,) = struct.unpack(">L", struct.pack(">f", f))
    sign_and_exponent = pattern & 0xF8000000  # bits 31..27 stay in place
    mantissa = ((pattern >> 20) & 0x7) << 24  # bits 22..20 move to 26..24
    raw = (sign_and_exponent | mantissa).to_bytes(4, byteorder=sys.byteorder)
    (result,) = struct.unpack("@f", raw)  # native byteorder
    return result
259
260
def float32_to_fp8e5m2(f):
    """Turns fp32 value into fp8e5m2

    Keeps the fp32 sign bit, the 5 most significant exponent bits and the
    2 most significant mantissa bits, packs those 8 bits into the top byte
    of a 32 bit pattern (remaining 24 bits zero) and returns that pattern
    interpreted as a float in the system's native byte order.

    NOTE(review): the 8 selected bits are reinterpreted as the top byte of an
    fp32 pattern, so the returned number is generally not numerically close
    to the input - confirm this is what consumers expect.
    """
    (pattern,) = struct.unpack(">L", struct.pack(">f", f))
    sign_and_exponent = pattern & 0xFC000000  # bits 31..26 stay in place
    mantissa = ((pattern >> 21) & 0x3) << 24  # bits 22..21 move to 25..24
    raw = (sign_and_exponent | mantissa).to_bytes(4, byteorder=sys.byteorder)
    (result,) = struct.unpack("@f", raw)  # native byteorder
    return result
267
268
# Element-wise (array) versions of the scalar fp32 converters, each producing
# float32 output. np.vectorize is a convenience wrapper: it lets the scalar
# functions be applied directly to whole arrays.
vect_f32_to_bf16 = np.vectorize(float32_to_bfloat16, otypes=(np.float32,))

vect_f32_to_fp8e4m3 = np.vectorize(float32_to_fp8e4m3, otypes=(np.float32,))

vect_f32_to_fp8e5m2 = np.vectorize(float32_to_fp8e5m2, otypes=(np.float32,))