blob: 384463ff83e4dadc9dfc91d93e48e0390aff7974 [file] [log] [blame]
Won Jeon64e4bfe2024-01-18 06:31:55 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
Jeremy Johnson1271c442023-09-05 11:39:26 +01005from enum import IntEnum
James Ward24dbc422022-10-19 12:20:31 +01006
7import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from tosa.DType import DType
9
# Upper limit on any single dimension of RESIZE inputs and outputs
MAX_RESIZE_DIMENSION = 16_384
12
# Data type information dictionary
# - str: filename abbreviation
# - width: number of bits needed for type (e.g. INT8 -> 8, FP32 -> 32)
# - json: JSON type string
DTYPE_ATTRIBUTES = {
    DType.BOOL: {"str": "b", "width": 1, "json": "BOOL"},
    DType.INT4: {"str": "i4", "width": 4, "json": "INT4"},
    DType.INT8: {"str": "i8", "width": 8, "json": "INT8"},
    DType.UINT8: {"str": "u8", "width": 8, "json": "UINT8"},
    DType.INT16: {"str": "i16", "width": 16, "json": "INT16"},
    DType.UINT16: {"str": "u16", "width": 16, "json": "UINT16"},
    DType.INT32: {"str": "i32", "width": 32, "json": "INT32"},
    DType.INT48: {"str": "i48", "width": 48, "json": "INT48"},
    DType.SHAPE: {"str": "s", "width": 64, "json": "SHAPE"},
    DType.FP16: {"str": "f16", "width": 16, "json": "FP16"},
    DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"},
    DType.FP32: {"str": "f32", "width": 32, "json": "FP32"},
    DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "json": "FP8E4M3"},
    DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "json": "FP8E5M2"},
}
33
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010034
class ComplianceMode(IntEnum):
    """Integer identifiers for the available compliance modes."""

    # Values are explicit and contiguous from zero; do not renumber.
    EXACT = 0
    DOT_PRODUCT = 1
    ULP = 2
    FP_SPECIAL = 3
    REDUCE_PRODUCT = 4
    ABS_ERROR = 5
    RELATIVE = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010045
46
class DataGenType(IntEnum):
    """Integer identifiers for the available data generator types."""

    # Values are explicit and contiguous from zero; do not renumber.
    PSEUDO_RANDOM = 0
    DOT_PRODUCT = 1
    OP_BOUNDARY = 2
    OP_FULLSET = 3
    OP_SPECIAL = 4
    FIXED_DATA = 5
Jeremy Johnson1271c442023-09-05 11:39:26 +010056
57
def dtypeWidth(dtype):
    """Look up the "width" entry recorded for a data type.

    Args:
        dtype: key into DTYPE_ATTRIBUTES

    Returns:
        The "width" value stored for dtype in DTYPE_ATTRIBUTES

    Raises:
        Exception if dtype has no entry in DTYPE_ATTRIBUTES
    """
    attributes = DTYPE_ATTRIBUTES.get(dtype)
    if attributes is None:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return attributes["width"]
64
65
def dtypeIsSupportedByCompliance(dtype):
    """Report whether the new data generation and compliance flow handles dtype.

    When dtype is a list or tuple only its first element is considered.
    Returns True for DType.FP32 and DType.FP16, False otherwise.
    """
    primary = dtype[0] if isinstance(dtype, (list, tuple)) else dtype
    return primary in (DType.FP32, DType.FP16)
Jeremy Johnson1271c442023-09-05 11:39:26 +010071
72
def getOpNameFromOpListName(opName):
    """Strip any suffix from a TOSA_OP_LIST name to recover the op name.

    Names beginning with one of the convolution op names are reduced to
    that op name; every other name is returned unchanged.
    """
    conv_ops = ("conv2d", "depthwise_conv2d", "transpose_conv2d", "conv3d")
    matches = (base for base in conv_ops if opName.startswith(base))
    return next(matches, opName)
79
80
def valueToName(item, value):
    """Get the name of an attribute with the given value.

    This convenience function is needed to print meaningful names for
    the values of the tosa.Op.Op and tosa.DType.DType classes.
    This would not be necessary if they were subclasses of Enum, or
    IntEnum, which, sadly, they are not.

    Double-underscore attributes (``__doc__``, ``__module__``, ...) are
    skipped so that interpreter bookkeeping is never reported as a match;
    this mirrors the filtering done by allDTypes.

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'

    Returns:
        The name of the first attribute (in dir() order) with a matching value

    Raises:
        ValueError if the value is not found
    """
    for attr in dir(item):
        # Dunder attributes are never the constants being looked up
        if attr.startswith("__"):
            continue
        if getattr(item, attr) == value:
            return attr
    raise ValueError(f"value ({value}) not found")
108
109
def allDTypes(*, excludes=None):
    """Get a set of all DType values, optionally excluding some values.

    This convenience function is needed to provide a sequence of DType values.
    This would be much easier if DType was a subclass of Enum, or IntEnum,
    as we could then iterate over the values directly, instead of using
    dir() to find the attributes and then check if they are what we want.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL]);
            None or an empty container excludes nothing

    Returns:
        A set of DType values
    """
    # Materialise once: a one-shot iterator would otherwise be consumed by
    # the first membership test and later values would slip through.
    excluded = tuple(excludes) if excludes else ()
    values = set()
    for name in dir(DType):
        if name.startswith("__"):
            continue
        candidate = getattr(DType, name)
        if callable(candidate) or candidate in excluded:
            continue
        values.add(candidate)
    return values
132
133
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    Always omits DType.UNKNOWN, DType.UINT8, DType.UINT16 and DType.SHAPE
    in addition to the excludes specified by the caller. The uncommon types
    are left out as the serializer lib does not support them.
    If you wish to include 'UNKNOWN', 'UINT8', 'UINT16' or 'SHAPE' use
    allDTypes instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16, DType.SHAPE}
    if excludes:
        omit.update(excludes)
    return allDTypes(excludes=omit)
152
153
def product(shape):
    """Multiply together every entry of shape; an empty shape gives 1."""
    total = 1
    for dim in shape:
        total *= dim
    return total
James Ward8b390432022-08-12 20:48:56 +0100159
160
def get_accum_dtype_from_tgTypes(dtypes):
    """Return the accumulate data type from the test generator's types.

    The accumulate type is the last entry of dtypes, which must be a
    list or a tuple.
    """
    assert isinstance(dtypes, (list, tuple))
    accum_dtype = dtypes[-1]
    return accum_dtype
James Ward8b390432022-08-12 20:48:56 +0100165
166
def get_wrong_output_type(op_name, rng, input_dtype):
    """Pick an output data type that is wrong for the given op and input.

    For "fully_connected" and "matmul" a fixed table of incorrect output
    types is used for the known input types. In every other case, including
    these two ops with an input type that has no table entry, any usable
    type other than the input type is treated as incorrect.

    Args:
        op_name: name of the operator under test
        rng: random generator providing choice(a=...)
            (presumably a NumPy generator - confirm against callers)
        input_dtype: DType value of the operator input

    Returns:
        One DType value chosen by rng from the incorrect types
    """
    incorrect_types = None
    if op_name in ("fully_connected", "matmul"):
        if input_dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        elif input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
            )

    if incorrect_types is None:
        # No specific table matched (other op, or an unlisted input type):
        # assume all types but the input type are incorrect. Without this
        # fallback an unlisted input type left incorrect_types unbound.
        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
    return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100213
214
def get_rank_mismatch_shape(rng, output_shape):
    """Return output_shape with its rank extended by 1, 2 or 3.

    The extra trailing dimensions are all of size 1, so the total element
    count remains the same. The caller's sequence is left untouched; a new
    list is returned, which also allows a tuple to be passed in.

    Args:
        rng: random generator providing choice()
        output_shape: sequence of dimension sizes

    Returns:
        A new list holding output_shape followed by the extra 1s
    """
    rank_modifier = int(rng.choice([1, 2, 3]))
    return list(output_shape) + [1] * rank_modifier
224
225
def float32_is_valid_bfloat16(f):
    """Return True if float value is valid bfloat16.

    The value is encoded as a 32 bit float and counts as valid when the
    least significant 16 bits of that encoding are all zero.
    """
    (raw_bits,) = struct.unpack(">L", struct.pack(">f", f))
    return (raw_bits & 0x0000FFFF) == 0
230
231
def float32_is_valid_float8(f):
    """Return True if float value is valid float8.

    The value is encoded as a 32 bit float and counts as valid when the
    least significant 24 bits of that encoding are all zero.
    NOTE(review): this looks like it matches the top-byte layout produced
    by float32_to_fp8e4m3 / float32_to_fp8e5m2 - confirm with the callers.
    """
    (raw_bits,) = struct.unpack(">L", struct.pack(">f", f))
    return (raw_bits & 0x00FFFFFF) == 0
236
237
def get_float32_bitstring(f):
    """Return the 32 bit float encoding of f as a big-endian string of bits.

    The result is always 32 characters long, most significant bit first.
    """
    packed = struct.pack(">f", f)
    as_int = int.from_bytes(packed, byteorder="big")
    return format(as_int, "032b")
242
243
def float32_to_bfloat16(f):
    """Turn an fp32 value into a bfloat16-valid fp32 value by truncation.

    Zeroes the least significant 16 bits of the 32 bit float encoding of
    the input and returns the resulting value as a float.
    The bit manipulation is done on the big-endian encoding; the result is
    then rebuilt using the system's native byte order.
    """
    (raw_bits,) = struct.unpack(">L", struct.pack(">f", f))
    kept_bits = raw_bits & 0xFFFF0000

    # Assume sys.byteorder matches system's underlying float byteorder
    native_bytes = kept_bits.to_bytes(4, byteorder=sys.byteorder)
    (result,) = struct.unpack("@f", native_bytes)  # native byteorder
    return result
259
260
def float32_to_fp8e4m3(f):
    """Turns fp32 value into fp8e4m3

    Takes the sign bit, the top 4 exponent bits and the top 3 mantissa bits
    of the 32 bit float encoding, packs them into the most significant byte
    of a 32 bit word (remaining 24 bits zero) and returns that word
    reinterpreted as a float using the native byte order.
    NOTE(review): the returned float is a carrier for the 8 bit pattern,
    not the numeric value of the input - confirm this is what callers expect.
    """
    (raw_bits,) = struct.unpack(">L", struct.pack(">f", f))
    sign_and_exponent = raw_bits & 0xF8000000  # bits 31..27 stay in place
    mantissa = (raw_bits >> 20) & 0x7  # bits 22..20
    packed_bits = sign_and_exponent | (mantissa << 24)
    native_bytes = packed_bits.to_bytes(4, byteorder=sys.byteorder)
    (result,) = struct.unpack("@f", native_bytes)  # native byteorder
    return result
267
268
def float32_to_fp8e5m2(f):
    """Turns fp32 value into fp8e5m2

    Takes the sign bit, the top 5 exponent bits and the top 2 mantissa bits
    of the 32 bit float encoding, packs them into the most significant byte
    of a 32 bit word (remaining 24 bits zero) and returns that word
    reinterpreted as a float using the native byte order.
    NOTE(review): the returned float is a carrier for the 8 bit pattern,
    not the numeric value of the input - confirm this is what callers expect.
    """
    (raw_bits,) = struct.unpack(">L", struct.pack(">f", f))
    sign_and_exponent = raw_bits & 0xFC000000  # bits 31..26 stay in place
    mantissa = (raw_bits >> 21) & 0x3  # bits 22..21
    packed_bits = sign_and_exponent | (mantissa << 24)
    native_bytes = packed_bits.to_bytes(4, byteorder=sys.byteorder)
    (result,) = struct.unpack("@f", native_bytes)  # native byteorder
    return result
275
276
# Element-wise array versions of the scalar converters, producing float32.
# np.vectorize is a convenience wrapper (essentially a Python-level loop),
# so these give array broadcasting rather than a speed-up over looping.
vect_f32_to_bf16 = np.vectorize(float32_to_bfloat16, otypes=(np.float32,))

vect_f32_to_fp8e4m3 = np.vectorize(float32_to_fp8e4m3, otypes=(np.float32,))

vect_f32_to_fp8e5m2 = np.vectorize(float32_to_fp8e5m2, otypes=(np.float32,))