blob: a8e321eab7801a780546556c7bb089480d49baa8 [file] [log] [blame]
Won Jeon64e4bfe2024-01-18 06:31:55 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
Jeremy Johnson1271c442023-09-05 11:39:26 +01005from enum import IntEnum
James Ward24dbc422022-10-19 12:20:31 +01006
7import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from tosa.DType import DType
9
# Largest dimension size allowed for the inputs and output of RESIZE.
MAX_RESIZE_DIMENSION = 16_384

# Highest tensor rank the test generator supports.
MAX_TENSOR_RANK = 6
15
# Data type information dictionary, keyed by DType:
# - str: filename abbreviation
# - width: number of bits needed for the type
# - fullset: precalculated number of distinct bit patterns of the type,
#   equal to 2^width
# - json: JSON type string
DTYPE_ATTRIBUTES = {
    DType.BOOL: {"str": "b", "width": 1, "fullset": 2, "json": "BOOL"},
    DType.INT4: {"str": "i4", "width": 4, "fullset": 16, "json": "INT4"},
    DType.INT8: {"str": "i8", "width": 8, "fullset": 256, "json": "INT8"},
    DType.UINT8: {"str": "u8", "width": 8, "fullset": 256, "json": "UINT8"},
    DType.INT16: {"str": "i16", "width": 16, "fullset": 65536, "json": "INT16"},
    DType.UINT16: {"str": "u16", "width": 16, "fullset": 65536, "json": "UINT16"},
    DType.INT32: {"str": "i32", "width": 32, "fullset": 1 << 32, "json": "INT32"},
    DType.INT48: {"str": "i48", "width": 48, "fullset": 1 << 48, "json": "INT48"},
    DType.SHAPE: {"str": "s", "width": 64, "fullset": 1 << 64, "json": "SHAPE"},
    DType.FP16: {"str": "f16", "width": 16, "fullset": 65536, "json": "FP16"},
    DType.BF16: {"str": "bf16", "width": 16, "fullset": 65536, "json": "BF16"},
    DType.FP32: {"str": "f32", "width": 32, "fullset": 1 << 32, "json": "FP32"},
    DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "fullset": 256, "json": "FP8E4M3"},
    DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "fullset": 256, "json": "FP8E5M2"},
}
37
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010038
class ComplianceMode(IntEnum):
    """Compliance mode types.

    Each member carries a fixed integer value; the numbering is explicit so
    the values stay stable if members are reordered in source.
    """

    EXACT = 0
    DOT_PRODUCT = 1
    ULP = 2
    FP_SPECIAL = 3
    REDUCE_PRODUCT = 4
    ABS_ERROR = 5
    RELATIVE = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010049
50
class DataGenType(IntEnum):
    """Data generator types.

    Members have explicit, stable integer values.
    """

    PSEUDO_RANDOM = 0
    DOT_PRODUCT = 1
    BOUNDARY = 2
    FULL_RANGE = 3
    SPECIAL = 4
    FIXED_DATA = 5
Jeremy Johnson1271c442023-09-05 11:39:26 +010060
61
def dtypeWidth(dtype):
    """Return the width (in bits) recorded for ``dtype`` in DTYPE_ATTRIBUTES.

    Raises:
        Exception: if ``dtype`` has no entry in DTYPE_ATTRIBUTES.
    """
    if dtype not in DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return DTYPE_ATTRIBUTES[dtype]["width"]
68
69
def dtypeIsFloat(dtype):
    """Return True when ``dtype`` is one of the floating point DTypes."""
    float_dtypes = (
        DType.BF16,
        DType.FP16,
        DType.FP32,
        DType.FP8E4M3,
        DType.FP8E5M2,
    )
    return dtype in float_dtypes
73
74
def dtypeIsSupportedByCompliance(dtype):
    """Types supported by the new data generation and compliance flow.

    ``dtype`` may be a single DType or a list/tuple of DTypes; for a sequence
    only its first element is checked.
    """
    leading = dtype[0] if isinstance(dtype, (list, tuple)) else dtype
    return leading in (DType.FP32, DType.FP16)
Jeremy Johnson1271c442023-09-05 11:39:26 +010080
81
def getOpNameFromOpListName(opName):
    """Get the op name from a TOSA_OP_LIST name that can have suffixes.

    Names starting with one of the known convolution prefixes are reduced
    to that prefix (first match in the order listed); any other name is
    returned unchanged.
    """
    known_prefixes = ("conv2d", "depthwise_conv2d", "transpose_conv2d", "conv3d")
    return next(
        (prefix for prefix in known_prefixes if opName.startswith(prefix)),
        opName,
    )
88
89
def valueToName(item, value):
    """Get the name of an attribute with the given value.

    This convenience function is needed to print meaningful names for
    the values of the tosa.Op.Op and tosa.DType.DType classes.
    This would not be necessary if they were subclasses of Enum, or
    IntEnum, which, sadly, they are not.

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)   # returns 'INT8'
        name = valueToName(DType, 4)            # returns 'INT8'

    Returns:
        The name of the first attribute, in dir() order, whose value
        compares equal to ``value``.

    Raises:
        ValueError if the value is not found
    """
    # Attribute names are never None, so None is a safe "not found" marker.
    found = next(
        (name for name in dir(item) if getattr(item, name) == value),
        None,
    )
    if found is None:
        raise ValueError(f"value ({value}) not found")
    return found
117
118
def allDTypes(*, excludes=None):
    """Get a set of all DType values, optionally excluding some values.

    This convenience function is needed to provide a sequence of DType values.
    This would be much easier if DType was a subclass of Enum, or IntEnum,
    as we could then iterate over the values directly, instead of using
    dir() to find the attributes and then check if they are what we want.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    skip = excludes if excludes else ()
    collected = set()
    for attr_name in dir(DType):
        # Methods and dunder attributes are not DType values.
        if callable(getattr(DType, attr_name)):
            continue
        if attr_name.startswith("__"):
            continue
        if getattr(DType, attr_name) in skip:
            continue
        collected.add(getattr(DType, attr_name))
    return collected
141
142
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    Always omits DType.UNKNOWN, DType.UINT8, DType.UINT16 and DType.SHAPE,
    in addition to the excludes specified by the caller. UNKNOWN, UINT8 and
    UINT16 are uncommon types that the serializer lib does not support.
    If you wish to include 'UNKNOWN', 'UINT8', 'UINT16' or 'SHAPE' use
    allDTypes instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16, DType.SHAPE}
    if excludes:
        omit.update(excludes)
    return allDTypes(excludes=omit)
161
162
def product(shape):
    """Return the product of all entries in ``shape`` (1 for an empty shape)."""
    total = 1
    for dim in shape:
        total *= dim
    return total
James Ward8b390432022-08-12 20:48:56 +0100168
169
def get_accum_dtypes_from_tgTypes(dtypes):
    """Get accumulate data-types from the test generator's defined types.

    Args:
        dtypes: list or tuple of DType values; the first entry is taken as
            the input type and the last entry as the output type.

    Returns:
        A list of DType values: [FP16, FP32] for FP16 input with FP16
        output, [FP32] for BF16 output, otherwise just [output type].
    """
    assert isinstance(dtypes, (list, tuple))
    input_dtype, output_dtype = dtypes[0], dtypes[-1]
    if input_dtype == DType.FP16 and output_dtype == DType.FP16:
        return [DType.FP16, DType.FP32]
    if output_dtype == DType.BF16:
        return [DType.FP32]
    # Default: accumulate in the output type.
    return [output_dtype]
James Ward8b390432022-08-12 20:48:56 +0100182
183
def get_wrong_output_type(op_name, rng, input_dtype):
    """Pick a random output DType that is incorrect for the op and input type.

    For "fully_connected" and "matmul", a per-input-type list of incorrect
    output types is used. For every other op, and for a fully_connected/matmul
    input type that has no entry below, every usable DType except the input
    type is treated as incorrect.

    Args:
        op_name: operator name string
        rng: random generator providing ``choice(a=...)``
            (presumably a NumPy Generator - confirm against callers)
        input_dtype: DType of the operator input

    Returns:
        One DType value chosen by ``rng`` from the incorrect candidates.
    """
    incorrect_types = None
    if op_name in ("fully_connected", "matmul"):
        if input_dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        elif input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
            )

    if incorrect_types is None:
        # Any other op, or a fully_connected/matmul input type not covered
        # above (which would otherwise leave incorrect_types unbound):
        # assume all types but the input type are incorrect.
        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
    return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100230
231
def get_rank_mismatch_shape(rng, output_shape):
    """Return a copy of ``output_shape`` with 1 to 3 trailing unit dims added.

    Appending size-1 dimensions raises the rank while keeping the total
    element count the same. The caller's shape is not modified; previously
    the list was extended in place, and a tuple shape raised TypeError.

    Args:
        rng: random generator providing ``choice`` (e.g. a NumPy Generator)
        output_shape: sequence (list or tuple) of dimension sizes

    Returns:
        A new list whose rank exceeds len(output_shape) by 1, 2 or 3.
    """
    # int() so list repetition never depends on the generator's scalar type.
    rank_modifier = int(rng.choice([1, 2, 3]))
    return list(output_shape) + [1] * rank_modifier
241
242
def float32_is_valid_bfloat16(f):
    """Return True if ``f`` is exactly representable as bfloat16.

    ``f`` is encoded as an fp32. The check passes when the low 16 bits of
    that encoding (the part bfloat16 drops) are all zero.
    """
    (word,) = struct.unpack(">L", struct.pack(">f", f))
    return (word & 0xFFFF) == 0
247
248
def float32_is_valid_float8(f):
    """Return True if only the top 8 bits of ``f``'s fp32 encoding may be set.

    This matches the packed form produced by float32_to_fp8e4m3 and
    float32_to_fp8e5m2, which keep an 8-bit pattern in the most significant
    byte and zero the remaining 24 bits. It is not a test of whether ``f`` is
    exactly representable as an FP8 number (e.g. 1.0 returns False).
    """
    (word,) = struct.unpack(">L", struct.pack(">f", f))
    return (word & 0x00FFFFFF) == 0
253
254
def get_float32_bitstring(f):
    """Return the 32-character bit string of ``f`` encoded as an fp32.

    The most significant bit (sign) comes first, followed by the 8 exponent
    bits and the 23 mantissa bits.
    """
    (word,) = struct.unpack(">L", struct.pack(">f", f))
    return format(word, "032b")
259
260
def float32_to_bfloat16(f):
    """Truncate an fp32 value to bfloat16 precision and return it as a float.

    The low 16 bits of the fp32 encoding are cleared. For finite values this
    drops magnitude, i.e. it rounds toward zero (not floor for negatives).
    A NaN whose payload lies only in those low bits becomes an infinity.
    The bit manipulation is done on a big-endian view. The result is
    rebuilt using sys.byteorder, which assumes the platform's float byte
    order matches its integer byte order.
    """
    (word,) = struct.unpack(">L", struct.pack(">f", f))
    truncated = word & 0xFFFF0000
    native_bytes = truncated.to_bytes(4, byteorder=sys.byteorder)
    return struct.unpack("@f", native_bytes)[0]
276
277
def float32_to_fp8e4m3(f):
    """Turns fp32 value into fp8e4m3, packed into the top byte of an fp32.

    The result's top byte is [sign | top 4 fp32 exponent bits | top 3 fp32
    mantissa bits]. The other 24 bits are zero, and the 32-bit pattern is
    returned reinterpreted as a native float.

    NOTE(review): the top exponent bits are copied rather than re-biasing
    the exponent. The packed value is therefore only "as expected" for a
    narrow input range: 2.0 maps to 2.0, but 1.0 maps to 2**-15. Confirm
    this is the intended encoding.
    """
    (word,) = struct.unpack(">L", struct.pack(">f", f))
    sign = word >> 31
    exponent_top4 = (word >> 27) & 0xF  # fp32 bits 30..27
    mantissa_top3 = (word >> 20) & 0x7  # fp32 bits 22..20
    packed = ((sign << 7) | (exponent_top4 << 3) | mantissa_top3) << 24
    native_bytes = packed.to_bytes(4, byteorder=sys.byteorder)
    return struct.unpack("@f", native_bytes)[0]  # native byteorder
284
285
def float32_to_fp8e5m2(f):
    """Turns fp32 value into fp8e5m2, packed into the top byte of an fp32.

    The result's top byte is [sign | top 5 fp32 exponent bits | top 2 fp32
    mantissa bits]. The other 24 bits are zero, and the 32-bit pattern is
    returned reinterpreted as a native float.

    NOTE(review): the top exponent bits are copied rather than re-biasing
    the exponent. The packed value is therefore only "as expected" for a
    narrow input range: 2.0 maps to 2.0, but 1.0 maps to 2**-7. Confirm
    this is the intended encoding.
    """
    (word,) = struct.unpack(">L", struct.pack(">f", f))
    sign = word >> 31
    exponent_top5 = (word >> 26) & 0x1F  # fp32 bits 30..26
    mantissa_top2 = (word >> 21) & 0x3  # fp32 bits 22..21
    packed = ((sign << 7) | (exponent_top5 << 2) | mantissa_top2) << 24
    native_bytes = packed.to_bytes(4, byteorder=sys.byteorder)
    return struct.unpack("@f", native_bytes)[0]  # native byteorder
292
293
# Element-wise versions of the scalar fp32 converters above; each returns a
# float32 array. np.vectorize is a convenience wrapper (essentially a Python
# loop over the elements), not a performance optimisation.
vect_f32_to_bf16 = np.vectorize(float32_to_bfloat16, otypes=(np.float32,))

vect_f32_to_fp8e4m3 = np.vectorize(float32_to_fp8e4m3, otypes=(np.float32,))

vect_f32_to_fp8e5m2 = np.vectorize(float32_to_fp8e5m2, otypes=(np.float32,))