blob: 75a0df52e1e12b84dfc85025be32ea328fe68929 [file] [log] [blame]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
Jeremy Johnson1271c442023-09-05 11:39:26 +01005from enum import IntEnum
James Ward24dbc422022-10-19 12:20:31 +01006
7import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from tosa.DType import DType
9
# Upper bound on any single dimension of the RESIZE operator's inputs and
# output (1 << 14 == 16384).
MAX_RESIZE_DIMENSION = 1 << 14
12
# Data type information dictionary
# - str: filename abbreviation
# - width: number of bits needed for the type (e.g. INT8 -> 8, INT32 -> 32,
#   BOOL -> 1)
# - json: JSON type string
DTYPE_ATTRIBUTES = {
    DType.BOOL: {"str": "b", "width": 1, "json": "BOOL"},
    DType.INT4: {"str": "i4", "width": 4, "json": "INT4"},
    DType.INT8: {"str": "i8", "width": 8, "json": "INT8"},
    DType.UINT8: {"str": "u8", "width": 8, "json": "UINT8"},
    DType.INT16: {"str": "i16", "width": 16, "json": "INT16"},
    DType.UINT16: {"str": "u16", "width": 16, "json": "UINT16"},
    DType.INT32: {"str": "i32", "width": 32, "json": "INT32"},
    DType.INT48: {"str": "i48", "width": 48, "json": "INT48"},
    DType.SHAPE: {"str": "s", "width": 64, "json": "SHAPE"},
    DType.FP16: {"str": "f16", "width": 16, "json": "FP16"},
    DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"},
    DType.FP32: {"str": "f32", "width": 32, "json": "FP32"},
}
31
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010032
class ComplianceMode(IntEnum):
    """Enumeration of the supported compliance mode types.

    Members are plain integers (IntEnum), numbered consecutively from 0.
    """

    EXACT = 0
    DOT_PRODUCT = 1
    ULP = 2
    FP_SPECIAL = 3
    REDUCE_PRODUCT = 4
41
42
class DataGenType(IntEnum):
    """Enumeration of the data generator types.

    Members are plain integers (IntEnum), numbered consecutively from 0.
    """

    PSEUDO_RANDOM = 0
    DOT_PRODUCT = 1
    OP_BOUNDARY = 2
    OP_FULLSET = 3
    OP_SPECIAL = 4
51
52
# Names of the additional (optional) data items used by the dot product
# data generator.
DG_DOT_PRODUCT_OPTIONAL_INFO = (
    "acc_type",
    "kernel",
    "axis",
)
55
56
def dtypeIsFloat(dtype):
    """Return True if ``dtype`` is one of FP16, BF16 or FP32, else False."""
    floating_types = (DType.FP16, DType.BF16, DType.FP32)
    return dtype in floating_types
59
60
def valueToName(item, value):
    """Get the name of an attribute with the given value.

    This convenience function is needed to print meaningful names for
    the values of the tosa.Op.Op and tosa.DType.DType classes.
    This would not be necessary if they were subclasses of Enum, or
    IntEnum, which, sadly, they are not.

    Dunder attributes (e.g. ``__doc__``, ``__module__``) are skipped, in the
    same way allDTypes() skips them, so that internal Python attributes can
    never be reported as the name of a value (for instance ``None`` would
    otherwise match ``__doc__`` on an undocumented class).

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)  # returns 'INT8'

    Returns:
        The name of the first attribute (in dir() order, which is
        alphabetical) found with a matching value.

    Raises:
        ValueError if the value is not found
    """
    for attr in dir(item):
        if attr.startswith("__"):
            # Internal Python attributes are never meaningful names here.
            continue
        if getattr(item, attr) == value:
            return attr
    raise ValueError(f"value ({value}) not found")
88
89
def allDTypes(*, excludes=None):
    """Return the set of every DType value, minus any caller-supplied excludes.

    DType is not an Enum, so its values cannot be iterated directly. Instead,
    the attribute names reported by dir() are walked, skipping dunder names
    and callables, and each remaining attribute value is collected.

    Args:
        excludes: iterable of DType values to leave out
            (e.g. [DType.INT8, DType.BOOL]); None or empty means no exclusions

    Returns:
        A set of DType values
    """
    skipped = excludes if excludes else ()
    collected = set()
    for attr_name in dir(DType):
        attr_value = getattr(DType, attr_name)
        if callable(attr_value) or attr_name.startswith("__"):
            continue
        if attr_value in skipped:
            continue
        collected.add(attr_value)
    return collected
112
113
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    Excludes uncommon types (DType.UNKNOWN, DType.UINT8, DType.UINT16), as
    the serializer lib does not support them, and also DType.SHAPE, in
    addition to the excludes specified by the caller.
    If you wish to include 'UNKNOWN', 'UINT8', 'UINT16' or 'SHAPE' use
    allDTypes instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16, DType.SHAPE}
    omit.update(excludes if excludes else ())
    return allDTypes(excludes=omit)
132
133
def product(shape):
    """Multiply together all entries of ``shape``; an empty shape yields 1."""
    total = 1
    for dim in shape:
        total = total * dim
    return total
James Ward8b390432022-08-12 20:48:56 +0100139
140
def get_accum_dtype_from_tgTypes(dtypes):
    """Return the accumulate data-type from the test generator's defined types.

    The accumulate type is stored as the final entry of ``dtypes``, which is
    asserted to be a list or a tuple.
    """
    assert isinstance(dtypes, (list, tuple))
    last_index = -1
    return dtypes[last_index]
James Ward8b390432022-08-12 20:48:56 +0100145
146
def get_wrong_output_type(op_name, rng, input_dtype):
    """Randomly choose an incorrect output DType for an operator.

    For "fully_connected" and "matmul" the candidate list is hand-picked per
    input type (INT8, INT16, or one of FP32/FP16/BF16). In every other case,
    including any other ``op_name``, all usable DTypes except ``input_dtype``
    are treated as incorrect.

    Args:
        op_name: name of the operator under test
        rng: random generator exposing ``choice(a=...)``
            (presumably a numpy Generator - confirm with callers)
        input_dtype: DType of the operator's input

    Returns:
        One DType picked from the candidates with ``rng.choice``.
    """
    incorrect_types = None
    if op_name in ("fully_connected", "matmul"):
        if input_dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
    if incorrect_types is None:
        # Fallback for every case not handled above (including other op
        # names), so incorrect_types is always bound before it is used.
        # Assume all types but the input type are incorrect
        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
    return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100183
184
def get_rank_mismatch_shape(rng, output_shape):
    """Return ``output_shape`` extended with 1 to 3 extra size-1 dimensions.

    The rank grows by an amount picked with ``rng.choice([1, 2, 3])``. Because
    only dimensions of size 1 are appended, the total element count is
    unchanged.

    Args:
        rng: random generator exposing ``choice`` (presumably a numpy
            Generator - confirm with callers)
        output_shape: sequence of dimension sizes (list or tuple); it is
            not modified

    Returns:
        A new list: the original dimensions followed by the added 1s.
    """
    rank_modifier = rng.choice([1, 2, 3])
    # Build a new list instead of using "+=", so the caller's shape is not
    # mutated in place and tuple shapes are accepted too.
    return list(output_shape) + [1] * rank_modifier
194
195
def float32_is_valid_bfloat16(f):
    """Return True if float value is valid bfloat16.

    ``f`` is packed as a 32-bit float. It is a valid bfloat16 value exactly
    when the low 16 bits of that bit pattern are all zero.
    """
    (bits,) = struct.unpack(">L", struct.pack(">f", f))
    low_half = bits & 0xFFFF
    return low_half == 0
200
201
def get_float32_bitstring(f):
    """Return the 32-bit float pattern of ``f`` as a string of '0'/'1'.

    The string is always 32 characters long and most-significant-bit first
    (big-endian), regardless of the host byte order.
    """
    packed = struct.pack(">f", f)
    (word,) = struct.unpack(">L", packed)
    return format(word, "032b")
206
207
def float32_to_bfloat16(f):
    """Turn an fp32 value into bfloat16 by truncating its low 16 bits.

    ``f`` is first packed as a 32-bit float. The least significant 16 bits
    of that bit pattern are then cleared, and the result is reinterpreted as
    a float. Clearing mantissa bits reduces the magnitude, so this truncates
    toward zero. It is not a floor for negative inputs: -3.1415927 becomes
    -3.140625.

    Both the packing and the unpacking use an explicit big-endian layout, so
    the result does not depend on the host's byte order.

    Returns:
        A Python float holding a bf16-representable value.

    Raises:
        OverflowError if ``f`` is too large to pack as a 32-bit float
        (raised by struct.pack).
    """
    (f32_bits,) = struct.unpack(">L", struct.pack(">f", f))
    bf16_bits = f32_bits & 0xFFFF0000  # keep sign, exponent, top 7 mantissa bits
    return struct.unpack(">f", struct.pack(">L", bf16_bits))[0]
223
224
# Element-wise version of float32_to_bfloat16 producing float32 output.
# NOTE: np.vectorize is a convenience wrapper, not a performance tool. NumPy
# documents it as essentially a Python for loop, so it is no faster than
# looping over the elements explicitly.
vect_f32_to_bf16 = np.vectorize(float32_to_bfloat16, otypes=(np.float32,))