blob: 6558bf8917056e9ec160676469a2b606908a8e23 [file] [log] [blame]
Won Jeon64e4bfe2024-01-18 06:31:55 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
Jeremy Johnson1271c442023-09-05 11:39:26 +01005from enum import IntEnum
James Ward24dbc422022-10-19 12:20:31 +01006
7import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from tosa.DType import DType
9
# Upper bound on every dimension of the RESIZE op's input and output tensors
MAX_RESIZE_DIMENSION = 1 << 14  # 16384
12
# Data type information dictionary
# - str: filename abbreviation
# - width: number of bits needed for the type
# - fullset: precalculated number of possible values in the data type's range, equal to 2^width
# - json: JSON type string
DTYPE_ATTRIBUTES = {
    DType.BOOL: {"str": "b", "width": 1, "fullset": 2, "json": "BOOL"},
    DType.INT4: {"str": "i4", "width": 4, "fullset": 16, "json": "INT4"},
    DType.INT8: {"str": "i8", "width": 8, "fullset": 256, "json": "INT8"},
    DType.UINT8: {"str": "u8", "width": 8, "fullset": 256, "json": "UINT8"},
    DType.INT16: {"str": "i16", "width": 16, "fullset": 65536, "json": "INT16"},
    DType.UINT16: {"str": "u16", "width": 16, "fullset": 65536, "json": "UINT16"},
    DType.INT32: {"str": "i32", "width": 32, "fullset": 1 << 32, "json": "INT32"},
    DType.INT48: {"str": "i48", "width": 48, "fullset": 1 << 48, "json": "INT48"},
    DType.SHAPE: {"str": "s", "width": 64, "fullset": 1 << 64, "json": "SHAPE"},
    DType.FP16: {"str": "f16", "width": 16, "fullset": 65536, "json": "FP16"},
    DType.BF16: {"str": "bf16", "width": 16, "fullset": 65536, "json": "BF16"},
    DType.FP32: {"str": "f32", "width": 32, "fullset": 1 << 32, "json": "FP32"},
    DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "fullset": 256, "json": "FP8E4M3"},
    DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "fullset": 256, "json": "FP8E5M2"},
}
34
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010035
class ComplianceMode(IntEnum):
    """Compliance mode types.

    Declared as an IntEnum, so every member compares equal to its integer
    value and can be used wherever a plain int code is expected.
    """

    EXACT = 0
    DOT_PRODUCT = 1
    ULP = 2
    FP_SPECIAL = 3
    REDUCE_PRODUCT = 4
    ABS_ERROR = 5
    RELATIVE = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010046
47
class DataGenType(IntEnum):
    """Data generator types.

    Declared as an IntEnum, so every member compares equal to its integer
    value and can be used wherever a plain int code is expected.
    """

    PSEUDO_RANDOM = 0
    DOT_PRODUCT = 1
    BOUNDARY = 2
    FULL_RANGE = 3
    SPECIAL = 4
    FIXED_DATA = 5
Jeremy Johnson1271c442023-09-05 11:39:26 +010057
58
def dtypeWidth(dtype):
    """Return the bit width recorded for ``dtype`` in DTYPE_ATTRIBUTES.

    Raises:
        Exception: if ``dtype`` has no entry in DTYPE_ATTRIBUTES
    """
    if dtype not in DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return DTYPE_ATTRIBUTES[dtype]["width"]
65
66
def dtypeIsSupportedByCompliance(dtype):
    """Return True if the new data generation and compliance flow handles ``dtype``.

    ``dtype`` may be a single DType or a list/tuple of DTypes; for a sequence
    only its first entry is checked. Currently only FP32 and FP16 qualify.
    """
    first_dtype = dtype[0] if isinstance(dtype, (list, tuple)) else dtype
    return first_dtype in (DType.FP32, DType.FP16)
Jeremy Johnson1271c442023-09-05 11:39:26 +010072
73
def getOpNameFromOpListName(opName):
    """Map a TOSA_OP_LIST entry name to its base op name.

    Entries that start with one of the convolution op names ("conv2d",
    "depthwise_conv2d", "transpose_conv2d", "conv3d") may carry a suffix and
    are reduced to that prefix (checked in that order); any other name is
    returned unchanged.
    """
    prefixes = ("conv2d", "depthwise_conv2d", "transpose_conv2d", "conv3d")
    return next((prefix for prefix in prefixes if opName.startswith(prefix)), opName)
80
81
def valueToName(item, value):
    """Return the name of the first attribute of ``item`` equal to ``value``.

    tosa.Op.Op and tosa.DType.DType are plain classes rather than Enum or
    IntEnum subclasses, so they offer no built-in way to map a value back to
    its name; this helper fills that gap so values can be printed readably.

    Attributes are searched in dir() order, which is alphabetical, so when
    several attributes share a value the alphabetically first name wins.

    Args:
        item: The class, or object, to search
        value: The value to look up

    Example:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'

    Returns:
        The name of the first attribute found with a matching value.

    Raises:
        ValueError: if no attribute has the given value
    """
    match = next((name for name in dir(item) if getattr(item, name) == value), None)
    if match is None:
        raise ValueError(f"value ({value}) not found")
    return match
109
110
def allDTypes(*, excludes=None):
    """Return the set of every DType value, minus any listed in ``excludes``.

    DType is not an Enum/IntEnum subclass, so its members cannot be iterated
    directly; instead dir(DType) is scanned for non-dunder, non-callable
    attributes.

    Args:
        excludes: iterable of DType values to leave out (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    skipped = excludes or ()
    values = set()
    for attr_name in dir(DType):
        attr_value = getattr(DType, attr_name)
        if callable(attr_value) or attr_name.startswith("__"):
            continue
        if attr_value in skipped:
            continue
        values.add(attr_value)
    return values
133
134
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    Always omits DType.UNKNOWN, DType.UINT8, DType.UINT16 and DType.SHAPE in
    addition to the excludes specified by the caller. UNKNOWN, UINT8 and
    UINT16 are uncommon types that the serializer lib does not support; SHAPE
    is likewise left out of this general-purpose set.
    If you wish to include 'UNKNOWN', 'UINT8', 'UINT16' or 'SHAPE' use
    allDTypes instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16, DType.SHAPE}
    omit.update(excludes if excludes else ())
    return allDTypes(excludes=omit)
153
154
def product(shape):
    """Return the product of all entries in ``shape`` (1 for an empty shape)."""
    total = 1
    for dim in shape:
        total *= dim
    return total
James Ward8b390432022-08-12 20:48:56 +0100160
161
def get_accum_dtype_from_tgTypes(dtypes):
    """Return the accumulate data-type from the test generator's defined types.

    The test generator passes its types as a list or tuple whose last entry is
    the accumulator type.
    """
    assert isinstance(dtypes, (list, tuple))
    accum_dtype = dtypes[-1]
    return accum_dtype
James Ward8b390432022-08-12 20:48:56 +0100166
167
def get_wrong_output_type(op_name, rng, input_dtype):
    """Randomly pick an output dtype that is incorrect for the op and input type.

    For "fully_connected" and "matmul" the candidates come from a fixed table
    keyed on the input dtype. For every other op, and for FC/MATMUL input
    dtypes not covered by that table, every usable dtype except the input
    dtype is treated as incorrect.

    Args:
        op_name: operator name, e.g. "matmul"
        rng: random generator exposing ``choice(a=...)`` (presumably a NumPy
            Generator -- confirm against callers)
        input_dtype: DType of the op's input tensor

    Returns:
        One DType value drawn from the incorrect candidates.
    """
    incorrect_types = None
    if op_name in ("fully_connected", "matmul"):
        if input_dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        elif input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
            )
    # Shared fallback: covers other ops AND FC/MATMUL input dtypes missing from
    # the table above, so incorrect_types is never left unbound.
    if incorrect_types is None:
        # Assume all types but the input type are incorrect
        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
    return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100214
215
def get_rank_mismatch_shape(rng, output_shape):
    """Return a copy of ``output_shape`` with extra trailing size-1 dimensions.

    Between 1 and 3 dimensions of size 1 (count chosen with ``rng.choice``)
    are appended, so the rank changes while the total element count stays
    the same. The caller's ``output_shape`` is not modified, and any sequence
    (list or tuple) is accepted.

    Args:
        rng: random generator exposing ``choice(sequence)``
        output_shape: sequence of dimension sizes

    Returns:
        A new list: the original dimensions followed by the added 1s.
    """
    rank_modifier = rng.choice([1, 2, 3])
    # Build a new list instead of `+=`, which would extend the caller's list
    # in place (and fails outright for tuples).
    return list(output_shape) + [1] * int(rank_modifier)
225
226
def float32_is_valid_bfloat16(f):
    """Return True if the float32 encoding of ``f`` is exactly a bfloat16.

    bfloat16 keeps only the top 16 bits of the float32 bit pattern, so ``f``
    qualifies when the low 16 bits of that pattern are all zero.
    """
    (word,) = struct.unpack(">L", struct.pack(">f", f))
    return (word & 0xFFFF) == 0
231
232
def float32_is_valid_float8(f):
    """Return True if ``f`` is a valid float8 container value.

    Float8 values in this module are carried as a float32 whose top 8 bits
    hold the fp8 bit pattern and whose low 24 bits are zero (the layout built
    by float32_to_fp8e4m3/float32_to_fp8e5m2). This checks only that layout:
    the low 24 bits of the float32 bit pattern of ``f`` must all be zero.
    """
    (word,) = struct.unpack(">L", struct.pack(">f", f))
    return (word & 0xFFFFFF) == 0
237
238
def get_float32_bitstring(f):
    """Return the 32-bit float encoding of ``f`` as a '0'/'1' string.

    The string is big-endian (sign bit first) and always 32 characters long.
    """
    (as_uint32,) = struct.unpack(">L", struct.pack(">f", f))
    return format(as_uint32, "032b")
243
244
def float32_to_bfloat16(f):
    """Turn an fp32 value into bfloat16 by truncating its low 16 bits.

    The sign, exponent and top 7 mantissa bits of the float32 encoding of
    ``f`` are kept and the remaining 16 bits are cleared, so the magnitude is
    truncated toward zero (a floor for non-negative values, a ceiling for
    negative ones). The bit pattern is packed and unpacked with an explicit
    big-endian layout on both sides, so host byte order does not matter.

    Returns:
        A Python float whose value is exactly representable in bfloat16.
    """
    (word,) = struct.unpack(">L", struct.pack(">f", f))
    truncated = word & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">L", truncated))[0]
260
261
def _float32_to_fp8_container(f, exp_bits, man_bits, max_finite, inf_code, nan_code):
    """Encode ``f`` as an fp8 bit pattern held in the top byte of a float32.

    The fp32 value is re-encoded with an ``exp_bits``-bit exponent (bias
    ``2**(exp_bits - 1) - 1``) and a ``man_bits``-bit mantissa:
    - extra mantissa bits are truncated (round toward zero);
    - finite values too large for the format saturate to ``max_finite``;
    - values below the smallest fp8 subnormal become a signed zero.

    Args:
        f: value to convert; must be packable as a float32
        exp_bits: number of fp8 exponent bits
        man_bits: number of fp8 mantissa bits
        max_finite: 7-bit magnitude code of the largest finite fp8 value
        inf_code: 7-bit magnitude code used for +/-infinity inputs
        nan_code: 7-bit magnitude code used for NaN inputs

    Returns:
        A float whose float32 bit pattern is ``fp8_bits << 24`` (low 24 bits
        zero). Such a container is always a finite float32.
    """
    (bits,) = struct.unpack(">L", struct.pack(">f", f))
    sign = (bits >> 31) << 7
    exp32 = (bits >> 23) & 0xFF
    man32 = bits & 0x7FFFFF
    if exp32 == 0xFF:
        magnitude = nan_code if man32 else inf_code
    elif exp32 == 0:
        # fp32 zero or subnormal: far below the smallest fp8 subnormal
        magnitude = 0
    else:
        bias = (1 << (exp_bits - 1)) - 1
        exp8 = exp32 - 127 + bias  # re-bias the exponent for the fp8 format
        if exp8 >= 1:
            magnitude = (exp8 << man_bits) | (man32 >> (23 - man_bits))
        else:
            # fp8 subnormal: shift the full significand (implicit 1 restored)
            magnitude = ((1 << 23) | man32) >> (24 - man_bits - exp8)
        # Codes of one sign are ordered by magnitude, so clamping the code
        # saturates the value at the largest finite fp8 number.
        magnitude = min(magnitude, max_finite)
    fp8_bits = sign | magnitude
    return struct.unpack(">f", struct.pack(">L", fp8_bits << 24))[0]


def float32_to_fp8e4m3(f):
    """Turns fp32 value into fp8e4m3 (4 exponent bits, bias 7, 3 mantissa bits).

    The fp8 bit pattern is returned in the top 8 bits of a float32 container
    (low 24 bits zero). FP8E4M3 has no infinities: its only special code is
    NaN (S.1111.111), so the largest finite magnitude is S.1111.110 (448).
    - mantissa: truncated toward zero;
    - finite overflow: saturates to +/-448;
    - +/-inf and NaN inputs: map to NaN.
    """
    return _float32_to_fp8_container(
        f, 4, 3, max_finite=0x7E, inf_code=0x7F, nan_code=0x7F
    )


def float32_to_fp8e5m2(f):
    """Turns fp32 value into fp8e5m2 (5 exponent bits, bias 15, 2 mantissa bits).

    The fp8 bit pattern is returned in the top 8 bits of a float32 container
    (low 24 bits zero).
    - mantissa: truncated toward zero;
    - finite overflow: saturates to +/-57344 (S.11110.11);
    - +/-inf: maps to infinity (S.11111.00);
    - NaN: maps to S.11111.10.
    """
    return _float32_to_fp8_container(
        f, 5, 2, max_finite=0x7B, inf_code=0x7C, nan_code=0x7E
    )
276
277
# Element-wise array versions of the scalar converters above, each producing
# float32 output. np.vectorize is a convenience wrapper that still calls the
# Python function once per element; it is not a performance optimization.
vect_f32_to_bf16 = np.vectorize(float32_to_bfloat16, otypes=(np.float32,))
vect_f32_to_fp8e4m3 = np.vectorize(float32_to_fp8e4m3, otypes=(np.float32,))
vect_f32_to_fp8e5m2 = np.vectorize(float32_to_fp8e5m2, otypes=(np.float32,))