blob: 318f29671099c0ae2cce9745fa868c4f58758348 [file] [log] [blame]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
Jeremy Johnson1271c442023-09-05 11:39:26 +01005from enum import IntEnum
James Ward24dbc422022-10-19 12:20:31 +01006
7import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from tosa.DType import DType
9
# Upper limit (2**14 == 16384) on any input or output dimension size for RESIZE
MAX_RESIZE_DIMENSION = 2**14
12
# Data type information dictionary, keyed by DType
# - str: filename abbreviation
# - width: number of BITS needed for the type (e.g. BOOL is 1, INT8 is 8)
# - json: JSON type string
DTYPE_ATTRIBUTES = {
    DType.BOOL: {"str": "b", "width": 1, "json": "BOOL"},
    DType.INT4: {"str": "i4", "width": 4, "json": "INT4"},
    DType.INT8: {"str": "i8", "width": 8, "json": "INT8"},
    DType.UINT8: {"str": "u8", "width": 8, "json": "UINT8"},
    DType.INT16: {"str": "i16", "width": 16, "json": "INT16"},
    DType.UINT16: {"str": "u16", "width": 16, "json": "UINT16"},
    DType.INT32: {"str": "i32", "width": 32, "json": "INT32"},
    DType.INT48: {"str": "i48", "width": 48, "json": "INT48"},
    DType.SHAPE: {"str": "s", "width": 64, "json": "SHAPE"},
    DType.FP16: {"str": "f16", "width": 16, "json": "FP16"},
    DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"},
    DType.FP32: {"str": "f32", "width": 32, "json": "FP32"},
}
31
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010032
class ComplianceMode(IntEnum):
    """Kinds of result checking available in the compliance flow.

    Each member has an explicit integer value. These values are presumably
    consumed outside Python (e.g. written to test descriptions), so do not
    renumber them without checking consumers. NOTE(review): confirm.
    """

    EXACT = 0
    DOT_PRODUCT = 1
    ULP = 2
    FP_SPECIAL = 3
    REDUCE_PRODUCT = 4
    ABS_ERROR = 5
Jeremy Johnson1271c442023-09-05 11:39:26 +010042
43
class DataGenType(IntEnum):
    """Kinds of test data generator.

    Each member has an explicit integer value. These values are presumably
    consumed outside Python, so keep them stable. NOTE(review): confirm.
    """

    PSEUDO_RANDOM = 0
    DOT_PRODUCT = 1
    OP_BOUNDARY = 2
    OP_FULLSET = 3
    OP_SPECIAL = 4
52
53
def dtypeIsSupportedByCompliance(dtype):
    """Return True if the new data generation and compliance flow supports *dtype*.

    *dtype* may be a single DType or a list/tuple of DTypes. For a
    list/tuple, only the first element is checked.
    """
    first_dtype = dtype[0] if isinstance(dtype, (list, tuple)) else dtype
    supported = (DType.FP32,)
    return first_dtype in supported
Jeremy Johnson1271c442023-09-05 11:39:26 +010059
60
def getOpNameFromOpListName(opName):
    """Map a TOSA_OP_LIST entry name, which may carry a suffix, to its op name.

    Names that start with one of the convolution op names are reduced to that
    op name. Any other name is returned unchanged.
    """
    conv_ops = ("conv2d", "depthwise_conv2d", "transpose_conv2d", "conv3d")
    return next((base for base in conv_ops if opName.startswith(base)), opName)
67
68
def valueToName(item, value):
    """Return the name of the first attribute of *item* equal to *value*.

    This helper exists to print meaningful names for values of the
    tosa.Op.Op and tosa.DType.DType classes. They are plain classes rather
    than Enum/IntEnum subclasses, so they cannot map values back to names
    themselves.

    Attributes are searched in ``dir(item)`` order, which is alphabetical.

    Args:
        item: The class, or object, to search for the value
        value: The value to look up

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)  # returns 'INT8'

    Returns:
        The name of the first attribute found with a matching value.

    Raises:
        ValueError: if no attribute has the value.
    """
    matching_names = (name for name in dir(item) if getattr(item, name) == value)
    found = next(matching_names, None)
    if found is None:
        raise ValueError(f"value ({value}) not found")
    return found
96
97
def allDTypes(*, excludes=None):
    """Return the set of all DType values, minus any in *excludes*.

    DType is a plain class rather than an Enum/IntEnum, so its values cannot
    be iterated directly. Instead, this scans ``dir(DType)`` and keeps the
    attributes that are neither dunder names nor callables.

    Args:
        excludes: iterable of DType values to leave out
            (e.g. [DType.INT8, DType.BOOL]); None or empty means none

    Returns:
        A set of DType values
    """
    skip = excludes or ()
    found = set()
    for attr_name in dir(DType):
        if attr_name.startswith("__"):
            continue
        attr_value = getattr(DType, attr_name)
        if callable(attr_value) or attr_value in skip:
            continue
        found.add(attr_value)
    return found
120
121
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    Always omits DType.UNKNOWN, DType.UINT8, DType.UINT16 and DType.SHAPE,
    in addition to any values the caller passes in *excludes*.
    UNKNOWN, UINT8 and UINT16 are uncommon types that the serializer lib
    does not support. SHAPE is also omitted; the reason is not stated here.
    If you need any of these four types, use allDTypes instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16, DType.SHAPE}
    omit.update(excludes if excludes else ())
    return allDTypes(excludes=omit)
140
141
def product(shape):
    """Return the product of all entries in *shape*.

    An empty shape gives 1, the multiplicative identity. For a tensor
    shape this is the element count.
    """
    running_total = 1
    for dim_size in shape:
        running_total *= dim_size
    return running_total
James Ward8b390432022-08-12 20:48:56 +0100147
148
def get_accum_dtype_from_tgTypes(dtypes):
    """Return the accumulate data-type from the test generator's defined types.

    The types are given as a list or tuple, with the accumulate type last.

    Raises:
        AssertionError: if *dtypes* is not a list or tuple (when asserts
            are enabled).
        IndexError: if *dtypes* is empty.
    """
    assert isinstance(dtypes, (list, tuple))
    accum_dtype = dtypes[-1]
    return accum_dtype
James Ward8b390432022-08-12 20:48:56 +0100153
154
def get_wrong_output_type(op_name, rng, input_dtype):
    """Pick a random output dtype that should be rejected for the given op.

    For fully_connected and matmul with an INT8, INT16 or floating-point
    (FP32/FP16/BF16) input, the candidates come from an explicit tuple.
    Presumably each tuple leaves out the op's valid output type for that
    input (TODO confirm against the spec).

    In every other case, all usable dtypes except the input dtype are
    treated as incorrect. Previously some combinations left the candidate
    list unset and failed with UnboundLocalError.

    Args:
        op_name: operator name, e.g. "matmul"
        rng: random generator providing a ``choice(a=...)`` method
        input_dtype: DType of the operator input

    Returns:
        A DType drawn from the incorrect candidates.
    """
    incorrect_types = None
    if op_name in ("fully_connected", "matmul"):
        if input_dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )

    if incorrect_types is None:
        # Assume all types but the input type are incorrect
        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
    return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100191
192
def get_rank_mismatch_shape(rng, output_shape):
    """Return *output_shape* with 1, 2 or 3 extra trailing dimensions of size 1.

    The appended dimensions all have size 1, so the rank changes but the
    total element count stays the same.

    A new list is returned and the caller's *output_shape* is not modified.
    This means shapes shared with other tensors are left intact, and tuple
    shapes are also accepted.

    Args:
        rng: random generator providing a ``choice`` method
        output_shape: sequence of dimension sizes

    Returns:
        A new list: the original dimensions followed by the extra 1s.
    """
    rank_modifier = rng.choice([1, 2, 3])
    return list(output_shape) + [1] * rank_modifier
202
203
def float32_is_valid_bfloat16(f):
    """Return True if *f* is exactly representable as a bfloat16.

    This holds when the low 16 bits of the value's fp32 encoding are all
    zero.
    """
    low_half_bits = get_float32_bitstring(f)[16:]
    return "1" not in low_half_bits
208
209
def get_float32_bitstring(f):
    """Return the fp32 encoding of *f* as a 32-character big-endian bit string.

    The first character is the sign bit, and the last character is the
    least significant mantissa bit.
    """
    packed = struct.pack(">f", f)
    as_unsigned = int.from_bytes(packed, byteorder="big")
    return format(as_unsigned, "032b")
214
215
def float32_to_bfloat16(f):
    """Truncate an fp32 value to bfloat16 precision and return it as a float.

    Clears the least significant 16 bits of the value's fp32 encoding. The
    remaining sign, exponent and top 7 mantissa bits form exactly a bfloat16
    value. Dropping mantissa bits truncates the magnitude, so the result
    rounds toward zero: a floor for positive inputs and a ceiling for
    negative inputs.

    Encoding and decoding both use explicit big-endian formats. The result
    therefore does not depend on the platform's native integer or float byte
    order, and no assumption is needed that those two byte orders match.

    Args:
        f: a number that ``struct`` can pack in ">f" (fp32) format

    Returns:
        A Python float holding the bfloat16-valid value.
    """
    (f32_bits,) = struct.unpack(">L", struct.pack(">f", f))
    # Keep the upper 16 bits (sign, exponent, 7 mantissa bits), zero the rest
    bf16_bits = f32_bits & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">L", bf16_bits))[0]
231
232
# Element-wise version of float32_to_bfloat16 that produces float32 output.
# NOTE: np.vectorize is a convenience wrapper, essentially a Python-level
# loop over the elements; it is not faster than looping explicitly.
vect_f32_to_bf16 = np.vectorize(float32_to_bfloat16, otypes=(np.float32,))