blob: 3b487defdae0c3c6c88a506bfc195403c4d4ea51 [file] [log] [blame]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
Jeremy Johnson1271c442023-09-05 11:39:26 +01005from enum import IntEnum
James Ward24dbc422022-10-19 12:20:31 +01006
7import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from tosa.DType import DType
9
# Maximum dimension size for output and inputs for RESIZE
MAX_RESIZE_DIMENSION = 16384

# Data type information dictionary, keyed by DType value
# - str: filename abbreviation
# - width: number of bits needed for type (e.g. 8 for INT8, 64 for SHAPE)
# - json: JSON type string
DTYPE_ATTRIBUTES = {
    DType.BOOL: {"str": "b", "width": 1, "json": "BOOL"},
    DType.INT4: {"str": "i4", "width": 4, "json": "INT4"},
    DType.INT8: {"str": "i8", "width": 8, "json": "INT8"},
    DType.UINT8: {"str": "u8", "width": 8, "json": "UINT8"},
    DType.INT16: {"str": "i16", "width": 16, "json": "INT16"},
    DType.UINT16: {"str": "u16", "width": 16, "json": "UINT16"},
    DType.INT32: {"str": "i32", "width": 32, "json": "INT32"},
    DType.INT48: {"str": "i48", "width": 48, "json": "INT48"},
    DType.SHAPE: {"str": "s", "width": 64, "json": "SHAPE"},
    DType.FP16: {"str": "f16", "width": 16, "json": "FP16"},
    DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"},
    DType.FP32: {"str": "f32", "width": 32, "json": "FP32"},
}
31
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010032
class ComplianceMode(IntEnum):
    """Compliance mode types.

    Members are IntEnum values, so they compare equal to their plain
    integer value and can be used wherever an int is expected.
    """

    # NOTE(review): the explicit numbering presumably has to match whatever
    # consumes these modes outside this file - confirm before renumbering.
    EXACT = 0
    DOT_PRODUCT = 1
    ULP = 2
    FP_SPECIAL = 3
    REDUCE_PRODUCT = 4
41
42
class DataGenType(IntEnum):
    """Data generator types.

    Members are IntEnum values, so they compare equal to their plain
    integer value and can be used wherever an int is expected.
    """

    # NOTE(review): the explicit numbering presumably has to match whatever
    # consumes these generator types outside this file - confirm before
    # renumbering.
    PSEUDO_RANDOM = 0
    DOT_PRODUCT = 1
    OP_BOUNDARY = 2
    OP_FULLSET = 3
    OP_SPECIAL = 4
51
52
def dtypeIsSupportedByCompliance(dtype):
    """Check if a type is supported by the new data generation and compliance flow.

    Args:
        dtype: a DType value, or a list/tuple of DType values, in which
            case only the first entry is checked

    Returns:
        True if the (first) data type is DType.FP32, otherwise False
    """
    first_dtype = dtype[0] if isinstance(dtype, (list, tuple)) else dtype
    supported = (DType.FP32,)
    return first_dtype in supported
Jeremy Johnson1271c442023-09-05 11:39:26 +010058
59
def getOpNameFromOpListName(opName):
    """Get the op name from a TOSA_OP_LIST name that can have suffixes.

    Args:
        opName: name from the op list, e.g. "conv2d_3x3"

    Returns:
        The bare convolution op name when opName starts with one of the
        known convolution names, otherwise opName unchanged
    """
    conv_ops = ("conv2d", "depthwise_conv2d", "transpose_conv2d", "conv3d")
    # First matching prefix wins; fall back to the name as given.
    return next((base for base in conv_ops if opName.startswith(base)), opName)
66
67
def valueToName(item, value):
    """Look up the attribute name that holds the given value.

    Helper for printing meaningful names for the values of classes such
    as tosa.Op.Op and tosa.DType.DType, which are plain classes rather
    than Enum/IntEnum subclasses and so cannot map values back to names
    themselves.

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)    # returns 'INT8'

    Returns:
        The name of the first attribute, in dir() order, whose value
        compares equal to the given value

    Raises:
        ValueError if the value is not found
    """
    matching_names = (name for name in dir(item) if getattr(item, name) == value)
    found = next(matching_names, None)
    if found is None:
        raise ValueError(f"value ({value}) not found")
    return found
95
96
def allDTypes(*, excludes=None):
    """Collect every DType value, minus any the caller wants left out.

    DType is a plain class rather than an Enum/IntEnum, so its values
    cannot be iterated directly. Instead the attributes reported by
    dir() are filtered: callables and double-underscore names are skipped.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    excludes = excludes or ()
    dtypes = set()
    for name in dir(DType):
        candidate = getattr(DType, name)
        if callable(candidate) or name.startswith("__"):
            continue
        if candidate not in excludes:
            dtypes.add(candidate)
    return dtypes
119
120
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    Always leaves out DType.UNKNOWN, DType.UINT8, DType.UINT16 and
    DType.SHAPE, in addition to the excludes specified by the caller.
    The uncommon types (UNKNOWN, UINT8, UINT16) are left out as the
    serializer lib does not support them.
    If you wish to include any of these types use allDTypes instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16, DType.SHAPE}
    if excludes:
        omit |= set(excludes)
    return allDTypes(excludes=omit)
139
140
def product(shape):
    """Multiply together all the entries of shape.

    Args:
        shape: iterable of dimension sizes

    Returns:
        The product of all entries, or 1 if shape is empty
    """
    total = 1
    for dim in shape:
        total *= dim
    return total
James Ward8b390432022-08-12 20:48:56 +0100146
147
def get_accum_dtype_from_tgTypes(dtypes):
    """Get the accumulate data-type from the test generator's defined types.

    Args:
        dtypes: list or tuple of types; the accumulate type is the last entry

    Returns:
        The last entry of dtypes
    """
    assert isinstance(dtypes, (list, tuple))
    accum_dtype = dtypes[-1]
    return accum_dtype
James Ward8b390432022-08-12 20:48:56 +0100152
153
def get_wrong_output_type(op_name, rng, input_dtype):
    """Pick an output data type that is incorrect for the given operator.

    For "fully_connected" and "matmul" with a known input type, the choice
    is made from a fixed list of output types that are wrong for that
    input. In every other case (any other operator, or an input type with
    no specific list) all usable types except the input type are treated
    as incorrect, so a candidate list is always defined.

    Args:
        op_name: name of the operator
        rng: random number generator providing choice(a=<sequence>)
        input_dtype: DType value of the operator's input

    Returns:
        One DType value chosen by rng from the incorrect types
    """
    incorrect_types = None

    if op_name in ("fully_connected", "matmul"):
        if input_dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )

    if incorrect_types is None:
        # No specific list applies: assume all types but the input type are
        # incorrect. Doing this for every unmatched case avoids leaving
        # incorrect_types unbound.
        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))

    return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100190
191
def get_rank_mismatch_shape(rng, output_shape):
    """Return a copy of output_shape with its rank extended.

    Appends one, two or three trailing dimensions of size 1 (the amount
    is chosen by rng), so the total element count remains the same.
    The shape passed in is left untouched.

    Args:
        rng: random number generator providing choice(<sequence>)
        output_shape: sequence of dimension sizes

    Returns:
        A new list holding the extended shape
    """
    rank_modifier = rng.choice([1, 2, 3])
    # Build a new list rather than using +=, which would modify the
    # caller's list in place (and fail for tuples).
    return list(output_shape) + [1] * rank_modifier
201
202
def float32_is_valid_bfloat16(f):
    """Return True if float value is valid bfloat16.

    The value is valid when the least significant 16 bits of its
    32 bit float representation are all zero.
    """
    low_half = get_float32_bitstring(f)[16:]
    return "1" not in low_half
207
208
def get_float32_bitstring(f):
    """Return a big-endian string of bits representing a 32 bit float.

    The result is always 32 characters long, most significant bit first.
    """
    packed = struct.pack(">f", f)
    (bits_as_int,) = struct.unpack(">L", packed)
    return format(bits_as_int, "032b")
213
214
def float32_to_bfloat16(f):
    """Turns fp32 value into bfloat16 by flooring.

    Zeroes the least significant 16 bits of the input fp32 value and
    returns this valid bfloat16 representation as fp32. As only low
    bits are cleared, the magnitude is never increased.
    The bit-wrangling is done on the big-endian representation of the
    value and converted back with an explicit big-endian unpack, which
    gives the same float whatever the system's native byte order is.
    """
    upper_half = get_float32_bitstring(f)[:16]
    floored_bits = int(upper_half, 2) << 16

    fp_bytes = floored_bits.to_bytes(4, byteorder="big")
    return struct.unpack(">f", fp_bytes)[0]
230
231
# Element-wise version of float32_to_bfloat16 for array inputs, producing
# np.float32 outputs.
# NOTE: np.vectorize is a convenience wrapper, not a performance
# optimisation - the function is still called once per element.
vect_f32_to_bf16 = np.vectorize(
    float32_to_bfloat16, otypes=(np.float32,)
)