blob: d79ab3cb3145562dcf4e817286c39cdd9594a5ec [file] [log] [blame]
# Copyright (c) 2021-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
5
6import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01007from tosa.DType import DType
8
# Upper limit on the size of any dimension of the RESIZE inputs and output.
MAX_RESIZE_DIMENSION = 1 << 14  # 16384
11
# Per-DType attributes: "str" is a short type name and "width" is the
# size of the type in bits. Built from a flat table (one row per dtype);
# dict insertion order follows the row order below.
DTYPE_ATTRIBUTES = {
    dtype: {"str": short_name, "width": bit_width}
    for dtype, short_name, bit_width in (
        (DType.BOOL, "b", 1),
        (DType.INT4, "i4", 4),
        (DType.INT8, "i8", 8),
        (DType.UINT8, "u8", 8),
        (DType.INT16, "i16", 16),
        (DType.UINT16, "u16", 16),
        (DType.INT32, "i32", 32),
        (DType.INT48, "i48", 48),
        (DType.FP16, "f16", 16),
        (DType.BF16, "bf16", 16),
        (DType.FP32, "f32", 32),
    )
}
25
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010026
def valueToName(item, value):
    """Return the name of the attribute of *item* whose value equals *value*.

    tosa.Op.Op and tosa.DType.DType are plain classes rather than Enum or
    IntEnum subclasses, so there is no built-in way to map one of their
    values back to its name. This helper fills that gap so meaningful
    names can be printed.

    Args:
        item: The class, or object, whose attributes are searched
        value: The value to look for

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)  # returns 'INT8'

    Returns:
        The name of the first matching attribute, in dir(item) order
        (dir() returns names sorted alphabetically).

    Raises:
        ValueError: if no attribute of *item* equals *value*
    """
    found = next(
        (name for name in dir(item) if getattr(item, name) == value),
        None,
    )
    if found is None:
        raise ValueError(f"value ({value}) not found")
    return found
54
55
def allDTypes(*, excludes=None):
    """Get a set of all DType values, optionally excluding some values.

    DType is not a subclass of Enum or IntEnum, so its values cannot be
    iterated directly. Instead, dir() lists its attributes and every
    non-callable attribute whose name does not start with "__" is treated
    as a DType value.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL]).
            Any iterable is accepted, including one-shot iterators such as
            generators.

    Returns:
        A set of DType values
    """
    # Materialise the exclusions once. A membership test against a one-shot
    # iterator (e.g. a generator) consumes it, so later tests would silently
    # stop excluding anything. A set also makes each lookup O(1).
    excluded = set(excludes or ())

    dtypes = set()
    for name in dir(DType):
        if name.startswith("__"):
            continue
        value = getattr(DType, name)
        if not callable(value) and value not in excluded:
            dtypes.add(value)
    return dtypes
78
79
def usableDTypes(*, excludes=None):
    """Get a set of the commonly usable DType values, minus any excludes.

    DType.UNKNOWN, DType.UINT8 and DType.UINT16 are always left out because
    the serializer lib does not support them. This is in addition to
    whatever the caller lists in *excludes*. Use allDTypes instead if
    'UNKNOWN', 'UINT8' or 'UINT16' are wanted.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    unsupported = {DType.UNKNOWN, DType.UINT8, DType.UINT16}
    return allDTypes(excludes=unsupported.union(excludes or ()))
98
99
def product(shape):
    """Return the product of all elements of *shape*.

    An empty *shape* yields 1 (the multiplicative identity).
    """
    total = 1
    for dim in shape:
        total *= dim
    return total
James Ward8b390432022-08-12 20:48:56 +0100105
106
def get_accum_dtype_from_tgTypes(dtypes):
    """Get the accumulate data-type from the test generator's defined types.

    The accumulate type is the last entry of *dtypes*.

    Args:
        dtypes: list or tuple of types defined by the test generator

    Returns:
        The last element of *dtypes*.

    Raises:
        AssertionError: if *dtypes* is neither a list nor a tuple (only
            when assertions are enabled).
        IndexError: if *dtypes* is empty.
    """
    assert isinstance(dtypes, (list, tuple))
    accum_dtype = dtypes[-1]
    return accum_dtype
James Ward8b390432022-08-12 20:48:56 +0100111
112
def get_wrong_output_type(op_name, rng, input_dtype):
    """Randomly choose an output DType that is incorrect for the given op.

    Only "fully_connected" and "matmul" are supported. Each supported input
    type has a fixed tuple of output types treated as incorrect, and one of
    them is picked with ``rng.choice``:
      * INT8 input: every listed type except INT32.
      * INT16 input: every listed type except INT48.
      * FP32/FP16/BF16 input: integer types only.

    Args:
        op_name: operator name, "fully_connected" or "matmul"
        rng: random generator providing ``choice(a=...)``; presumably a
            numpy Generator (TODO confirm against callers)
        input_dtype: DType of the operator's input

    Returns:
        One element of the incorrect-types tuple, as returned by
        ``rng.choice``.

    Raises:
        ValueError: if *op_name* or *input_dtype* is not supported.
    """
    # Fail explicitly rather than falling through with ``incorrect_types``
    # unbound (which would surface as an UnboundLocalError).
    if op_name not in ("fully_connected", "matmul"):
        raise ValueError(f"No wrong output types defined for op '{op_name}'")

    # Tuple order matters: it determines what rng.choice returns.
    if input_dtype == DType.INT8:
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT48,
            DType.FP32,
            DType.FP16,
        )
    elif input_dtype == DType.INT16:
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.FP32,
            DType.FP16,
        )
    elif input_dtype in (DType.FP32, DType.FP16, DType.BF16):
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
        )
    else:
        raise ValueError(
            f"No wrong output types defined for op '{op_name}' "
            f"with input dtype {input_dtype}"
        )
    return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100146
147
def float32_is_valid_bfloat16(f):
    """Return True if float value is valid bfloat16.

    *f* is packed as an IEEE-754 float32 (``struct`` rounds wider floats to
    float32). bfloat16 keeps only the upper 16 bits of that encoding, so
    the value is valid bfloat16 exactly when the lower 16 bits are zero.
    """
    (raw_bits,) = struct.unpack(">L", struct.pack(">f", f))
    return (raw_bits & 0xFFFF) == 0
152
153
def get_float32_bitstring(f):
    """Return the IEEE-754 float32 encoding of *f* as a '0'/'1' string.

    The string is big-endian (most significant bit first: sign, exponent,
    mantissa) and always exactly 32 characters long.
    """
    encoded = struct.pack(">f", f)
    as_uint = int.from_bytes(encoded, byteorder="big")
    return format(as_uint, "032b")
158
159
def float32_to_bfloat16(f):
    """Turn an fp32 value into bfloat16 by truncating its low 16 bits.

    *f* is packed as an IEEE-754 float32. The least significant 16 bits of
    that encoding are then cleared, and the result is returned as a Python
    float holding a bfloat16-valid value.

    Clearing mantissa bits never increases the magnitude, so this rounds
    toward zero. It acts as a floor only for non-negative values.

    Both the pack and the unpack use explicit big-endian struct formats, so
    the result does not depend on the host's byte order.

    NOTE: a NaN whose payload lies entirely in the low 16 bits would become
    infinity. Python's default NaN keeps its quiet bit (upper half), so it
    stays NaN.
    """
    (f32_bits,) = struct.unpack(">L", struct.pack(">f", f))
    # Keep sign, 8 exponent bits and the top 7 mantissa bits.
    bf16_bits = f32_bits & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">L", bf16_bits))[0]
175
176
# Element-wise version of float32_to_bfloat16 that accepts arrays (and
# scalars) and produces np.float32 output. np.vectorize is a convenience
# wrapper: it still calls the Python function once per element (NumPy
# describes it as essentially a for loop), so it is not faster than an
# explicit loop.
vect_f32_to_bf16 = np.vectorize(float32_to_bfloat16, otypes=(np.float32,))