blob: 29ae8989b6befc60e3391c416b6ef0259e9d774d [file] [log] [blame]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
5
6import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01007from tosa.DType import DType
8
# Maximum dimension size for output and inputs for RESIZE
MAX_RESIZE_DIMENSION = 16384

# Per-DType attributes: "str" is the short type name used in generated
# names, "width" is the bit width of the type.
DTYPE_ATTRIBUTES = {
    dtype: {"str": short_name, "width": bits}
    for dtype, short_name, bits in (
        (DType.BOOL, "b", 1),
        (DType.INT4, "i4", 4),
        (DType.INT8, "i8", 8),
        (DType.UINT8, "u8", 8),
        (DType.INT16, "i16", 16),
        (DType.UINT16, "u16", 16),
        (DType.INT32, "i32", 32),
        (DType.INT48, "i48", 48),
        (DType.FP16, "f16", 16),
        (DType.BF16, "bf16", 16),
        (DType.FP32, "f32", 32),
    )
}
25
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010026
def valueToName(item, value):
    """Return the name of the attribute of *item* whose value equals *value*.

    The tosa.Op.Op and tosa.DType.DType classes are plain classes holding
    constants rather than Enum/IntEnum subclasses, so there is no built-in
    way to map a value back to its symbolic name. This helper provides that
    reverse lookup so meaningful names can be printed.

    Args:
        item: The class, or object, whose attributes are searched
        value: The value to look for

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)           # returns 'INT8'

    Returns:
        The name of the first matching attribute, in the (alphabetical)
        order produced by dir(item).

    Raises:
        ValueError: if no attribute of *item* has the given value.
    """
    # Attribute names are always strings, so None is a safe "not found" marker.
    found = next(
        (name for name in dir(item) if getattr(item, name) == value), None
    )
    if found is None:
        raise ValueError(f"value ({value}) not found")
    return found
54
55
def allDTypes(*, excludes=None):
    """Get a set of all DType values, optionally excluding some values.

    DType is not an Enum/IntEnum subclass, so its values cannot be iterated
    directly. Instead every non-dunder, non-callable attribute reported by
    dir(DType) is treated as a DType value.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL]).
            Any iterable is accepted, including one-shot iterators such as
            generators.

    Returns:
        A set of DType values
    """
    # Materialize the exclusions once. Testing membership repeatedly against
    # a one-shot iterator (e.g. a generator) would consume it, so later
    # attributes would only be compared against the items not yet consumed.
    # A tuple keeps the original ==-based membership semantics.
    excluded = tuple(excludes) if excludes else ()
    candidates = (
        getattr(DType, name) for name in dir(DType) if not name.startswith("__")
    )
    return {
        value
        for value in candidates
        if not callable(value) and value not in excluded
    }
78
79
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    In addition to any caller-specified exclusions, the uncommon types
    DType.UNKNOWN, DType.UINT8 and DType.UINT16 are always left out, as the
    serializer lib does not support them. Use allDTypes instead if you need
    'UNKNOWN', 'UINT8' or 'UINT16'.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    always_omitted = {DType.UNKNOWN, DType.UINT8, DType.UINT16}
    return allDTypes(excludes=always_omitted | set(excludes or ()))
98
99
def product(shape):
    """Return the product of all elements of *shape* (1 when it is empty)."""
    result = 1
    for extent in shape:
        result *= extent
    return result
James Ward8b390432022-08-12 20:48:56 +0100105
106
def get_accum_dtype_from_tgTypes(dtypes):
    """Return the accumulate data-type from the test generator's defined types.

    The accumulator type is taken to be the final entry of *dtypes*.

    Args:
        dtypes: list or tuple of types as defined by the test generator.

    Returns:
        The last element of *dtypes*.
    """
    assert isinstance(dtypes, (list, tuple))
    return dtypes[-1]
James Ward8b390432022-08-12 20:48:56 +0100111
112
def get_wrong_output_type(op_name, rng, input_dtype):
    """Pick a random output DType that is incorrect for the given op and input.

    For "fully_connected" and "matmul" a hand-picked tuple of incorrect
    output types is used for INT8, INT16 and floating-point (FP32, FP16,
    BF16) inputs. In every other case (other ops, or an input type not listed
    for those ops) all usable DTypes except the input type are assumed to be
    incorrect.

    Args:
        op_name: name of the operator, e.g. "matmul".
        rng: random generator exposing choice(a=...) (presumably a
            numpy.random.Generator - confirm against callers).
        input_dtype: DType of the operator's input.

    Returns:
        One DType value picked with rng.choice.
    """
    incorrect_types = None
    if op_name in ("fully_connected", "matmul"):
        if input_dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
    if incorrect_types is None:
        # Covers ops without a specific list AND input types not listed above;
        # without this fallback such combinations leave incorrect_types
        # unassigned and rng.choice raises UnboundLocalError.
        # Assume all types but the input type are incorrect.
        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
    return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100149
150
def float32_is_valid_bfloat16(f):
    """Report whether *f*, encoded as a 32-bit float, is exactly bfloat16.

    bfloat16 keeps only the upper 16 bits of the IEEE-754 single-precision
    encoding, so the value is representable when the lower 16 bits are zero.
    """
    (encoded,) = struct.unpack(">L", struct.pack(">f", f))
    return (encoded & 0xFFFF) == 0
155
156
def get_float32_bitstring(f):
    """Return the IEEE-754 single-precision encoding of *f* as a bit string.

    The result is 32 characters of '0'/'1', most significant bit (the sign)
    first, regardless of the host's byte order.
    """
    (as_uint32,) = struct.unpack(">L", struct.pack(">f", f))
    return format(as_uint32, "032b")
161
162
def float32_to_bfloat16(f):
    """Turn an fp32 value into bfloat16 by truncating its low 16 bits.

    The value is encoded as an IEEE-754 single-precision float and the
    least significant 16 bits of that encoding are cleared. Clearing
    mantissa bits rounds the magnitude down, i.e. rounds toward zero (for
    negative inputs the result is >= the input, so this is not a floor).

    NaN inputs stay NaN: if the NaN payload lived only in the discarded
    bits, the quiet-NaN bit is set so the result does not become infinity.

    Args:
        f: value to convert (anything struct can pack as a 32-bit float).

    Returns:
        A Python float holding the truncated, bfloat16-valid value.
    """
    (bits,) = struct.unpack(">L", struct.pack(">f", f))
    truncated = bits & 0xFFFF0000

    # Exponent all ones with a non-zero mantissa means NaN. If only the
    # discarded low mantissa bits were set, truncation would yield +/-inf.
    is_nan = (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF) != 0
    if is_nan and (truncated & 0x007F0000) == 0:
        truncated |= 0x00400000

    # Pack and unpack with the same explicit byte order, so no assumption is
    # needed about how the host lays out integers relative to floats.
    return struct.unpack(">f", struct.pack(">L", truncated))[0]
178
179
# Element-wise wrapper applying float32_to_bfloat16 to NumPy arrays, with the
# output dtype fixed to float32. np.vectorize is a convenience wrapper only:
# it still calls the Python function once per element, so it is not faster
# than an explicit loop.
vect_f32_to_bf16 = np.vectorize(float32_to_bfloat16, otypes=(np.float32,))