blob: 6387d068e70e29759d6c0dd8874ae367fd6b6946 [file] [log] [blame]
Won Jeon64e4bfe2024-01-18 06:31:55 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
Jeremy Johnson1271c442023-09-05 11:39:26 +01005from enum import IntEnum
James Ward24dbc422022-10-19 12:20:31 +01006
7import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from tosa.DType import DType
9
# Upper bound on any dimension size of the RESIZE operator's input and
# output tensors (2**14 == 16384).
MAX_RESIZE_DIMENSION = 1 << 14
12
# Data type information dictionary, keyed by DType
# - str: filename abbreviation
# - width: number of bits needed for the type (e.g. INT8 -> 8, INT32 -> 32)
# - json: JSON type string
DTYPE_ATTRIBUTES = {
    DType.BOOL: {"str": "b", "width": 1, "json": "BOOL"},
    DType.INT4: {"str": "i4", "width": 4, "json": "INT4"},
    DType.INT8: {"str": "i8", "width": 8, "json": "INT8"},
    DType.UINT8: {"str": "u8", "width": 8, "json": "UINT8"},
    DType.INT16: {"str": "i16", "width": 16, "json": "INT16"},
    DType.UINT16: {"str": "u16", "width": 16, "json": "UINT16"},
    DType.INT32: {"str": "i32", "width": 32, "json": "INT32"},
    DType.INT48: {"str": "i48", "width": 48, "json": "INT48"},
    DType.SHAPE: {"str": "s", "width": 64, "json": "SHAPE"},
    DType.FP16: {"str": "f16", "width": 16, "json": "FP16"},
    DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"},
    DType.FP32: {"str": "f32", "width": 32, "json": "FP32"},
}
31
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010032
class ComplianceMode(IntEnum):
    """Kinds of compliance check that can be applied to a test result.

    Members have fixed integer values (starting at 0); keep them stable,
    because the numeric value is what an IntEnum compares and serializes as.
    """

    EXACT = 0  # exact match
    DOT_PRODUCT = 1  # dot-product based check
    ULP = 2  # units-in-the-last-place check
    FP_SPECIAL = 3  # floating-point special values
    REDUCE_PRODUCT = 4  # reduce-product check
    ABS_ERROR = 5  # absolute-error check
Jeremy Johnson1271c442023-09-05 11:39:26 +010042
43
class DataGenType(IntEnum):
    """Kinds of data generator used to create test input data.

    Members have fixed integer values (starting at 0); keep them stable,
    because the numeric value is what an IntEnum compares and serializes as.
    """

    PSEUDO_RANDOM = 0  # pseudo-random values
    DOT_PRODUCT = 1  # data tailored for dot-product style ops
    OP_BOUNDARY = 2  # op boundary values
    OP_FULLSET = 3  # op full set of values
    OP_SPECIAL = 4  # op special values
    FIXED_DATA = 5  # fixed (predetermined) data
Jeremy Johnson1271c442023-09-05 11:39:26 +010053
54
def dtypeIsSupportedByCompliance(dtype):
    """Report whether dtype is handled by the new data generation and compliance flow.

    ``dtype`` may be a single DType, or a list/tuple of DTypes in which case
    only its first entry is considered. Currently FP32 and FP16 are supported.
    """
    candidate = dtype[0] if isinstance(dtype, (list, tuple)) else dtype
    supported_types = (DType.FP32, DType.FP16)
    return candidate in supported_types
Jeremy Johnson1271c442023-09-05 11:39:26 +010060
61
def getOpNameFromOpListName(opName):
    """Map a TOSA_OP_LIST entry name back to its base op name.

    Entries that start with conv2d, depthwise_conv2d, transpose_conv2d or
    conv3d may carry suffixes; those return the matching base name. Any
    other name is returned unchanged.
    """
    base_names = ("conv2d", "depthwise_conv2d", "transpose_conv2d", "conv3d")
    return next(
        (base for base in base_names if opName.startswith(base)),
        opName,
    )
68
69
def valueToName(item, value):
    """Return the name of the attribute of ``item`` that equals ``value``.

    The tosa.Op.Op and tosa.DType.DType classes hold plain integer
    constants rather than Enum/IntEnum members, so they offer no reverse
    lookup from value to name; this helper provides one for printing.

    Args:
        item: The class, or object, whose attributes are searched.
        value: The value to look for.

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'

    Returns:
        The name of the first matching attribute in ``dir(item)`` order.

    Raises:
        ValueError: if no attribute has the given value.
    """
    matching_names = (
        attr_name for attr_name in dir(item) if getattr(item, attr_name) == value
    )
    found = next(matching_names, None)
    if found is None:
        raise ValueError(f"value ({value}) not found")
    return found
97
98
def allDTypes(*, excludes=None):
    """Return the set of every DType value, minus any in ``excludes``.

    DType is a plain class of constants rather than an Enum, so its values
    cannot be iterated directly. Instead this walks ``dir(DType)`` and keeps
    each attribute that is neither a dunder name nor callable.

    Args:
        excludes: iterable of DType values to leave out
            (e.g. [DType.INT8, DType.BOOL]); None or empty excludes nothing.

    Returns:
        A set of DType values
    """
    skip = excludes if excludes else ()
    found = set()
    for attr_name in dir(DType):
        if attr_name.startswith("__"):
            continue
        attr_value = getattr(DType, attr_name)
        if callable(attr_value) or attr_value in skip:
            continue
        found.add(attr_value)
    return found
121
122
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    Always excludes DType.UNKNOWN, DType.UINT8, DType.UINT16 and DType.SHAPE,
    in addition to the excludes specified by the caller. UNKNOWN, UINT8 and
    UINT16 are uncommon types that the serializer lib does not support.
    SHAPE is also omitted; this is presumably because it is not a general
    tensor element type (TODO confirm).
    If you need any of these types, use allDTypes instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16, DType.SHAPE}
    omit.update(excludes if excludes else ())
    return allDTypes(excludes=omit)
141
142
def product(shape):
    """Return the product of all entries of ``shape`` (1 for an empty shape).

    Typically used to get the element count of a tensor shape.
    """
    total = 1
    for dim in iter(shape):
        total *= dim
    return total
James Ward8b390432022-08-12 20:48:56 +0100148
149
def get_accum_dtype_from_tgTypes(dtypes):
    """Return the accumulate data type from the test generator's defined types.

    ``dtypes`` must be a list or tuple; the accumulator type is its last entry.
    """
    assert isinstance(dtypes, (list, tuple))
    accum_dtype = dtypes[-1]
    return accum_dtype
James Ward8b390432022-08-12 20:48:56 +0100154
155
def get_wrong_output_type(op_name, rng, input_dtype):
    """Randomly choose an output DType that is incorrect for the given op.

    For ``fully_connected`` and ``matmul`` the candidates depend on the input
    type:
      - INT8 input: INT4, INT8, INT16, INT48, FP32, FP16
      - INT16 input: INT4, INT8, INT16, INT32, FP32, FP16
      - FP32/FP16/BF16 input: INT4, INT8, INT16, INT32, INT48
    For every other op, and for ``fully_connected``/``matmul`` with any other
    input type, every usable DType except ``input_dtype`` is a candidate.

    Args:
        op_name: TOSA operator name, e.g. "matmul".
        rng: random generator exposing ``choice(a=...)`` (presumably a
            numpy.random.Generator; confirm with callers).
        input_dtype: DType of the operator's input.

    Returns:
        One DType value picked from the candidates with ``rng.choice``.
    """
    incorrect_types = None
    if op_name in ("fully_connected", "matmul"):
        if input_dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
    if incorrect_types is None:
        # No op/type-specific list applies: assume all types but the input
        # type are incorrect. This also covers fully_connected/matmul with an
        # input type not listed above, which previously left incorrect_types
        # unbound.
        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
    return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100192
193
def get_rank_mismatch_shape(rng, output_shape):
    """Return a copy of ``output_shape`` with its rank increased.

    Appends 1, 2 or 3 trailing dimensions of size 1 (the count is drawn with
    ``rng.choice``). The rank changes but the total element count stays the
    same. The caller's ``output_shape`` is not modified, and any sequence
    (list or tuple) is accepted.

    Args:
        rng: random generator exposing ``choice`` (presumably a
            numpy.random.Generator; confirm with callers).
        output_shape: sequence of dimension sizes.

    Returns:
        A new list: the original dimensions followed by the added 1s.
    """
    rank_modifier = int(rng.choice([1, 2, 3]))
    # Build a new list instead of `+=`. The in-place form mutated the caller's
    # list, failed on tuples and broadcast-added on numpy arrays.
    return list(output_shape) + [1] * rank_modifier
203
204
def float32_is_valid_bfloat16(f):
    """Return True if the fp32 encoding of ``f`` is also a valid bfloat16.

    bfloat16 keeps only the upper 16 bits of an fp32 encoding, so the value
    qualifies when the lower 16 bits of that encoding are all zero.
    """
    (encoding,) = struct.unpack(">L", struct.pack(">f", f))
    low_half = encoding & 0xFFFF
    return low_half == 0
209
210
def get_float32_bitstring(f):
    """Return the 32-bit IEEE 754 single-precision encoding of ``f`` as text.

    The result is a 32-character string of '0'/'1', most significant bit
    (the sign bit) first, i.e. big-endian bit order.
    """
    raw_bytes = struct.pack(">f", f)
    as_int = int.from_bytes(raw_bytes, byteorder="big")
    return format(as_int, "032b")
215
216
def float32_to_bfloat16(f):
    """Turn an fp32 value into bfloat16 by truncating its low 16 bits.

    ``f`` is packed as an fp32, and the least significant 16 bits of that
    encoding are zeroed. Dropping mantissa bits rounds the magnitude toward
    zero; for negative values this is not a floor. The result is returned as
    a Python float that is exactly representable in bfloat16.

    Both the pack and the unpack use explicit big-endian struct formats, so
    the result does not depend on the host's byte order.

    NOTE: a NaN whose payload lies only in the low 16 bits would become an
    infinity after truncation. Python's default NaN keeps its payload in the
    high bits and is unaffected.
    """
    f32_bits = struct.unpack(">L", struct.pack(">f", f))[0]
    bf16_bits = f32_bits & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">L", bf16_bits))[0]
232
233
# Element-wise version of float32_to_bfloat16 for arrays/sequences, producing
# float32 NumPy output. np.vectorize is a convenience wrapper: it still calls
# the Python function once per element and is not faster than a plain loop.
vect_f32_to_bf16 = np.vectorize(float32_to_bfloat16, otypes=(np.float32,))