blob: 76e738899408d2beea1489f656eaef7bf8299551 [file] [log] [blame]
Won Jeon64e4bfe2024-01-18 06:31:55 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
Jeremy Johnson1271c442023-09-05 11:39:26 +01005from enum import IntEnum
James Ward24dbc422022-10-19 12:20:31 +01006
7import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from tosa.DType import DType
9
# Upper bound on the size of any input or output dimension for RESIZE
# (16384 == 2**14).
MAX_RESIZE_DIMENSION = 1 << 14
12
# Data type information dictionary, keyed by DType:
# - str: abbreviation used in file names
# - width: size of the type in BITS (e.g. 8 for INT8, 48 for INT48)
# - json: JSON type string
DTYPE_ATTRIBUTES = {
    DType.BOOL: {"str": "b", "width": 1, "json": "BOOL"},
    DType.INT4: {"str": "i4", "width": 4, "json": "INT4"},
    DType.INT8: {"str": "i8", "width": 8, "json": "INT8"},
    DType.UINT8: {"str": "u8", "width": 8, "json": "UINT8"},
    DType.INT16: {"str": "i16", "width": 16, "json": "INT16"},
    DType.UINT16: {"str": "u16", "width": 16, "json": "UINT16"},
    DType.INT32: {"str": "i32", "width": 32, "json": "INT32"},
    DType.INT48: {"str": "i48", "width": 48, "json": "INT48"},
    DType.SHAPE: {"str": "s", "width": 64, "json": "SHAPE"},
    DType.FP16: {"str": "f16", "width": 16, "json": "FP16"},
    DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"},
    DType.FP32: {"str": "f32", "width": 32, "json": "FP32"},
}
31
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010032
class ComplianceMode(IntEnum):
    """Enumeration of the compliance checking modes.

    Being an IntEnum, every member compares equal to its integer value
    (for example ``ComplianceMode.ULP == 2``).
    """

    EXACT = 0
    DOT_PRODUCT = 1
    ULP = 2
    FP_SPECIAL = 3
    REDUCE_PRODUCT = 4
    ABS_ERROR = 5
    RELATIVE = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010043
44
class DataGenType(IntEnum):
    """Enumeration of the data generator kinds.

    Being an IntEnum, every member compares equal to its integer value
    (for example ``DataGenType.FIXED_DATA == 5``).
    """

    PSEUDO_RANDOM = 0
    DOT_PRODUCT = 1
    OP_BOUNDARY = 2
    OP_FULLSET = 3
    OP_SPECIAL = 4
    FIXED_DATA = 5
Jeremy Johnson1271c442023-09-05 11:39:26 +010054
55
def dtypeIsSupportedByCompliance(dtype):
    """Return True if dtype is handled by the new data generation and compliance flow.

    ``dtype`` may be a single DType or a list/tuple of DTypes; for a
    sequence only its first entry is checked. Currently only FP32 and FP16
    are accepted.
    """
    first = dtype[0] if isinstance(dtype, (list, tuple)) else dtype
    return first in (DType.FP32, DType.FP16)
Jeremy Johnson1271c442023-09-05 11:39:26 +010061
62
def getOpNameFromOpListName(opName):
    """Recover the op name from a TOSA_OP_LIST name that may carry a suffix.

    Only the convolution ops (conv2d, depthwise_conv2d, transpose_conv2d,
    conv3d) are recognised as prefixes, tried in that order; any other name
    is returned unchanged.
    """
    conv_prefixes = ("conv2d", "depthwise_conv2d", "transpose_conv2d", "conv3d")
    return next(
        (prefix for prefix in conv_prefixes if opName.startswith(prefix)),
        opName,
    )
69
70
def valueToName(item, value):
    """Return the name of the first attribute of ``item`` equal to ``value``.

    Useful for printing readable names for values of the tosa.Op.Op and
    tosa.DType.DType classes, which are plain classes of integer constants
    rather than Enum/IntEnum subclasses.

    Attributes are searched in ``dir(item)`` order (alphabetical), so when
    several attributes share a value the alphabetically first one wins.

    Args:
        item: The class, or object, to search
        value: The value to look for

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)  # returns 'INT8'

    Returns:
        The name of the first attribute whose value equals ``value``.

    Raises:
        ValueError if no attribute has that value
    """
    candidates = (attr for attr in dir(item) if getattr(item, attr) == value)
    # Attribute names are always strings, so None safely means "no match".
    found = next(candidates, None)
    if found is None:
        raise ValueError(f"value ({value}) not found")
    return found
98
99
def allDTypes(*, excludes=None):
    """Return the set of every DType value, minus any in ``excludes``.

    DType is a plain class of integer constants rather than an Enum, so
    its values are discovered via ``dir()``: dunder names and callable
    attributes are skipped, everything else is treated as a DType value.

    Args:
        excludes: iterable of DType values to leave out
            (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    skip = excludes if excludes else ()
    result = set()
    for attr_name in dir(DType):
        if attr_name.startswith("__"):
            continue
        candidate = getattr(DType, attr_name)
        if callable(candidate) or candidate in skip:
            continue
        result.add(candidate)
    return result
122
123
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    Always omits the uncommon types DType.UNKNOWN, DType.UINT8 and
    DType.UINT16 (as the serializer lib does not support them), and also
    DType.SHAPE, in addition to the excludes specified by the caller.
    If you wish to include 'UNKNOWN', 'UINT8', 'UINT16' or 'SHAPE' use
    allDTypes instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16, DType.SHAPE}
    omit.update(excludes if excludes else ())
    return allDTypes(excludes=omit)
142
143
def product(shape):
    """Return the product of all entries of ``shape`` (1 for an empty shape)."""
    total = 1
    for dim in shape:
        total *= dim
    return total
James Ward8b390432022-08-12 20:48:56 +0100149
150
def get_accum_dtype_from_tgTypes(dtypes):
    """Return the accumulate data type from the test generator's defined types.

    ``dtypes`` must be a list or tuple; the accumulate type is its last entry.
    """
    assert isinstance(dtypes, (list, tuple))
    accum_dtype = dtypes[-1]
    return accum_dtype
James Ward8b390432022-08-12 20:48:56 +0100155
156
def get_wrong_output_type(op_name, rng, input_dtype):
    """Randomly choose an output DType that is incorrect for the given op.

    fully_connected and matmul have hand-picked lists of incorrect output
    types for INT8, INT16 and FP32/FP16/BF16 inputs. For every other op --
    and, as a fallback, for a fully_connected/matmul input type with no
    hand-picked list -- every usable DType except ``input_dtype`` is
    treated as incorrect.

    Args:
        op_name: TOSA op name, e.g. "matmul"
        rng: random generator providing ``choice(a=...)`` (presumably a
            numpy Generator -- confirm against callers)
        input_dtype: DType of the op's input

    Returns:
        One of the candidate incorrect types, as picked by ``rng.choice``.
    """
    incorrect_types = None
    if op_name in ("fully_connected", "matmul"):
        if input_dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
    if incorrect_types is None:
        # Assume all types but the input type are incorrect. This fallback
        # also covers fully_connected/matmul input types without a
        # hand-picked list above, so incorrect_types is always bound.
        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
    return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100193
194
def get_rank_mismatch_shape(rng, output_shape):
    """Return a copy of output_shape with its rank increased by 1 to 3.

    Trailing dimensions of size 1 are appended, so the total element count
    is unchanged. The caller's ``output_shape`` is left untouched, and any
    sequence of ints (list or tuple) is accepted.

    Args:
        rng: random generator providing ``choice`` (presumably a numpy
            Generator -- confirm against callers)
        output_shape: sequence of dimension sizes

    Returns:
        A new list: the entries of output_shape followed by 1-3 extra 1s.
    """
    # int() normalises the numpy integer that rng.choice may return.
    rank_modifier = int(rng.choice([1, 2, 3]))
    # Build a new list rather than using "+=", which would extend the
    # caller's list in place (and fail outright for a tuple).
    return list(output_shape) + [1] * rank_modifier
204
205
def float32_is_valid_bfloat16(f):
    """Return True if f, encoded as fp32, is exactly representable as bfloat16.

    That holds when the low 16 bits of the big-endian fp32 encoding are
    all zero.
    """
    (bits,) = struct.unpack(">L", struct.pack(">f", f))
    return (bits & 0xFFFF) == 0
210
211
def get_float32_bitstring(f):
    """Return f's fp32 encoding as a 32-character big-endian string of bits."""
    raw = struct.pack(">f", f)
    return format(int.from_bytes(raw, byteorder="big"), "032b")
216
217
def float32_to_bfloat16(f):
    """Truncate an fp32 value to bfloat16 precision, returned as a float.

    The value is first encoded as fp32; the low 16 bits of that encoding
    are then cleared. This is truncation toward zero (for negative values
    it is not a floor). The result is a Python float that is exactly
    representable in bfloat16.

    A NaN whose payload lives only in the discarded low 16 bits would
    otherwise truncate to an infinity, so its quiet bit is set to keep the
    result a NaN (the sign is preserved).
    """
    (bits,) = struct.unpack(">L", struct.pack(">f", f))
    upper = bits & 0xFFFF0000
    is_nan = (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF) != 0
    if is_nan and (upper & 0x007F0000) == 0:
        upper |= 0x00400000
    # Packing and unpacking with the same explicit byte order avoids any
    # assumption about the platform's native float byte order.
    return struct.unpack(">f", struct.pack(">L", upper))[0]
233
234
# Element-wise version of float32_to_bfloat16 producing float32 arrays.
# Note: np.vectorize is a convenience wrapper (essentially a Python-level
# loop over the elements), not a performance optimization.
vect_f32_to_bf16 = np.vectorize(float32_to_bfloat16, otypes=(np.float32,))