blob: 8ff62f1e8aeadeabd96140be50ebde86abfd8054 [file] [log] [blame]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
5
6import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01007from tosa.DType import DType
8
# Upper bound on any input or output dimension size for RESIZE (2**14).
MAX_RESIZE_DIMENSION = 1 << 14
11
# Per-DType attributes, keyed by DType value:
#   "str"   - short string label for the type
#   "width" - size of one element in bits
DTYPE_ATTRIBUTES = {
    dtype: {"str": label, "width": bits}
    for dtype, label, bits in (
        (DType.BOOL, "b", 1),
        (DType.INT4, "i4", 4),
        (DType.INT8, "i8", 8),
        (DType.UINT8, "u8", 8),
        (DType.INT16, "i16", 16),
        (DType.UINT16, "u16", 16),
        (DType.INT32, "i32", 32),
        (DType.INT48, "i48", 48),
        (DType.FP16, "f16", 16),
        (DType.BF16, "bf16", 16),
        (DType.FP32, "f32", 32),
    )
}
25
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010026
def valueToName(item, value):
    """Get the name of an attribute with the given value.

    This convenience function is needed to print meaningful names for
    the values of the tosa.Op.Op and tosa.DType.DType classes.
    This would not be necessary if they were subclasses of Enum, or
    IntEnum, which, sadly, they are not.

    Only public data attributes are considered: dunder attributes
    (e.g. ``__doc__``, ``__module__``) and callables are skipped, matching
    the filtering done by allDTypes().

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)  # returns 'INT8'

    Returns:
        The name of the first matching attribute, in dir() order
        (alphabetical).

    Raises:
        ValueError if the value is not found
    """
    for attr in dir(item):
        # Dunders and methods are never enum-like values; comparing against
        # them can give spurious hits (e.g. value None matching __doc__).
        if attr.startswith("__"):
            continue
        candidate = getattr(item, attr)
        if callable(candidate):
            continue
        if candidate == value:
            return attr
    raise ValueError(f"value ({value}) not found")
54
55
def allDTypes(*, excludes=None):
    """Return every DType value as a set, minus any listed in ``excludes``.

    DType is a plain class rather than an Enum/IntEnum, so its values cannot
    be iterated directly. Instead, dir(DType) is scanned and every attribute
    that is neither a dunder nor callable is treated as a DType value.

    Args:
        excludes: optional iterable of DType values to leave out
            (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    skip = excludes if excludes else ()
    found = set()
    for name in dir(DType):
        if name.startswith("__"):
            continue
        candidate = getattr(DType, name)
        if callable(candidate) or candidate in skip:
            continue
        found.add(candidate)
    return found
78
79
def usableDTypes(*, excludes=None):
    """Return the set of commonly usable DType values.

    DType.UNKNOWN, DType.UINT8 and DType.UINT16 are always left out, because
    the serializer lib does not support them. Any values in ``excludes`` are
    removed as well. Call allDTypes() instead if 'UNKNOWN', 'UINT8' or
    'UINT16' are needed.

    Args:
        excludes: optional iterable of additional DType values to drop
            (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16}
    if excludes:
        omit.update(excludes)
    return allDTypes(excludes=omit)
98
99
def product(shape):
    """Return the product of all values in ``shape``.

    An empty ``shape`` gives 1, the empty product.

    Args:
        shape: iterable of numbers (e.g. dimension sizes)

    Returns:
        The product of the elements of ``shape``.
    """
    total = 1
    for dim in shape:
        total *= dim
    return total
James Ward8b390432022-08-12 20:48:56 +0100105
106
def get_accum_dtype_from_tgTypes(dtypes):
    """Return the accumulator data-type from the test generator's types.

    The accumulator type is taken to be the last entry of ``dtypes``.

    Args:
        dtypes: list or tuple of types defined by the test generator

    Returns:
        The final element of ``dtypes``.
    """
    assert isinstance(dtypes, (list, tuple))
    accum_dtype = dtypes[-1]
    return accum_dtype
James Ward8b390432022-08-12 20:48:56 +0100111
112
def get_wrong_output_type(op_name, rng, input_dtype):
    """Randomly choose an output DType that is incorrect for the given op.

    For "fully_connected" and "matmul" with an INT8, INT16 or floating-point
    (FP32/FP16/BF16) input, a hand-picked tuple of incorrect output types is
    used. For every other op/input combination, all usable DTypes except the
    input type are assumed to be incorrect.

    Args:
        op_name: name of the operator, e.g. "matmul"
        rng: random generator with a ``choice(a=...)`` method
            (presumably a NumPy Generator -- confirm against callers)
        input_dtype: DType of the operator input

    Returns:
        One value chosen by ``rng.choice`` from the incorrect types.
    """
    # None means "no hand-picked list applies". Initialising it here
    # guarantees the fallback below runs for uncovered combinations instead
    # of reaching the return with the name unbound (UnboundLocalError).
    incorrect_types = None
    if op_name in ("fully_connected", "matmul"):
        if input_dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
    if incorrect_types is None:
        # Assume all types but the input type are incorrect
        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
    return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100149
150
def get_rank_mismatch_shape(rng, output_shape):
    """Return a copy of ``output_shape`` with its rank increased by 1 to 3.

    The added trailing dimensions are all 1, so the total element count is
    unchanged. The caller's ``output_shape`` is not modified.

    Args:
        rng: random generator with a ``choice`` method, used to pick how
            many dimensions (1, 2 or 3) to append
        output_shape: list or tuple of dimension sizes

    Returns:
        A new list: ``output_shape`` followed by 1-3 extra dimensions of 1.
    """
    rank_modifier = rng.choice([1, 2, 3])
    # Build a new list instead of using ``+=``, which would extend the
    # caller's list in place (and raise TypeError for a tuple).
    return list(output_shape) + [1] * int(rank_modifier)
160
161
def float32_is_valid_bfloat16(f):
    """Return True if ``f``, encoded as float32, is a valid bfloat16 value.

    A float32 value is exactly representable as bfloat16 when the lower 16
    bits of its 32-bit encoding are all zero.
    """
    (encoded,) = struct.unpack(">L", struct.pack(">f", f))
    return (encoded & 0xFFFF) == 0
166
167
def get_float32_bitstring(f):
    """Return ``f`` encoded as float32, as a 32-character big-endian bit string.

    The first character is the sign bit, followed by the 8 exponent bits and
    the 23 mantissa bits.
    """
    encoded = struct.pack(">f", f)
    return format(int.from_bytes(encoded, "big"), "032b")
172
173
def float32_to_bfloat16(f):
    """Turn an fp32 value into bfloat16 by truncating its low 16 bits.

    The least significant 16 bits of the float32 encoding of ``f`` are
    cleared. This rounds the magnitude toward zero, so it equals flooring
    only for non-negative values. The result is returned as a Python float
    that is exactly representable in bfloat16.

    NaN inputs remain NaN. The quiet bit is set before truncating, so a NaN
    whose payload lies only in the cleared bits cannot turn into infinity.

    Packing and unpacking both use explicit big-endian formats, so no
    assumption about the platform's native float byte order is needed.

    Raises:
        OverflowError: if ``f`` is too large for float32 (from struct.pack).
    """
    (bits,) = struct.unpack(">L", struct.pack(">f", f))
    is_nan = (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF) != 0
    if is_nan:
        # Quiet bit lives in the upper half, so it survives the mask below.
        bits |= 0x00400000
    bits &= 0xFFFF0000
    (truncated,) = struct.unpack(">f", struct.pack(">L", bits))
    return truncated
189
190
# Element-wise float32_to_bfloat16 over NumPy arrays, producing float32
# output. np.vectorize is a convenience wrapper: it still calls the Python
# function once per element (essentially a for loop), so it is not faster
# than an explicit loop.
vect_f32_to_bf16 = np.vectorize(float32_to_bfloat16, otypes=(np.float32,))