blob: 3cd03707191fe59635cadd15689edc7a26a85217 [file] [log] [blame]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
James Ward24dbc422022-10-19 12:20:31 +01003import struct
4import sys
5
6import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01007from tosa.DType import DType
8
# Upper bound on the size of any input or output dimension used for RESIZE.
MAX_RESIZE_DIMENSION = 16_384
11
# Per-DType metadata: "str" is a short type-name string (e.g. "i8") and
# "width" is the element width (in bits, judging by the values).
DTYPE_ATTRIBUTES = {
    dtype: {"str": type_str, "width": bit_width}
    for dtype, type_str, bit_width in (
        (DType.BOOL, "b", 1),
        (DType.INT4, "i4", 4),
        (DType.INT8, "i8", 8),
        (DType.UINT8, "u8", 8),
        (DType.INT16, "i16", 16),
        (DType.UINT16, "u16", 16),
        (DType.INT32, "i32", 32),
        (DType.INT48, "i48", 48),
        (DType.SHAPE, "i64", 64),
        (DType.FP16, "f16", 16),
        (DType.BF16, "bf16", 16),
        (DType.FP32, "f32", 32),
    )
}
26
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010027
def valueToName(item, value):
    """Return the name of the attribute of *item* whose value equals *value*.

    tosa.Op.Op and tosa.DType.DType hold plain constants rather than being
    Enum/IntEnum subclasses, so there is no built-in way to map a value back
    to its name. This helper provides one, so meaningful names can be printed.

    Args:
        item: The class, or object, whose attributes are searched.
        value: The value to look for.

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)  # returns 'INT8'

    Returns:
        The first attribute name, in dir() order, whose value compares equal
        to *value*.

    Raises:
        ValueError: if no attribute has a matching value.
    """
    # Lazily scan attributes; stops at the first match just like a loop would.
    matching_names = (
        attr_name for attr_name in dir(item) if getattr(item, attr_name) == value
    )
    # Attribute names are always strings, so None safely signals "no match".
    found = next(matching_names, None)
    if found is None:
        raise ValueError(f"value ({value}) not found")
    return found
55
56
def allDTypes(*, excludes=None):
    """Return the set of every DType constant, minus any listed in *excludes*.

    DType is not an Enum/IntEnum, so its values cannot be iterated directly.
    Instead, dir() is scanned and every attribute that is neither callable
    nor a dunder name is treated as a DType value.

    Args:
        excludes: optional iterable of DType values to leave out
            (e.g. [DType.INT8, DType.BOOL]).

    Returns:
        A set of DType values
    """
    skip = excludes if excludes else ()
    dtypes = set()
    for attr_name in dir(DType):
        attr_value = getattr(DType, attr_name)
        # Methods and other callables are not DType values.
        if callable(attr_value):
            continue
        # Python's own dunder attributes (e.g. __module__) are not either.
        if attr_name.startswith("__"):
            continue
        if attr_value in skip:
            continue
        dtypes.add(attr_value)
    return dtypes
79
80
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    Always excludes the uncommon types DType.UNKNOWN, DType.UINT8,
    DType.UINT16 and DType.SHAPE, in addition to the excludes specified by
    the caller. UNKNOWN, UINT8 and UINT16 are left out because the serializer
    lib does not support them. SHAPE is also omitted from the "usable" set
    (NOTE(review): presumably because it is not a general tensor element
    type; confirm).
    If you wish to include any of these types, use allDTypes instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16, DType.SHAPE}
    if excludes:
        omit.update(excludes)
    return allDTypes(excludes=omit)
99
100
def product(shape):
    """Multiply together every entry of *shape*.

    Args:
        shape: iterable of numbers (e.g. tensor dimension sizes).

    Returns:
        The product of all entries, or 1 for an empty iterable.
    """
    total = 1
    for dim in shape:
        total *= dim
    return total
James Ward8b390432022-08-12 20:48:56 +0100106
107
def get_accum_dtype_from_tgTypes(dtypes):
    """Return the accumulate data-type from the test generator's defined types.

    The accumulate type is taken to be the last element of *dtypes*.

    Args:
        dtypes: list or tuple of types, as defined by the test generator.

    Returns:
        The final element of *dtypes*.
    """
    assert isinstance(dtypes, (list, tuple))
    accum_dtype = dtypes[-1]
    return accum_dtype
James Ward8b390432022-08-12 20:48:56 +0100112
113
def get_wrong_output_type(op_name, rng, input_dtype):
    """Choose a random output DType that is incorrect for the op and input.

    For "fully_connected" and "matmul" with INT8, INT16 or floating-point
    (FP32/FP16/BF16) inputs, the choice is made from a fixed list of incorrect
    types. In every other case (any other op, or an input type not listed),
    all usable DTypes except ``input_dtype`` are assumed to be incorrect.

    Args:
        op_name: Operator name, e.g. "fully_connected" or "matmul".
        rng: Random generator with a NumPy-style ``choice(a=...)`` method.
        input_dtype: DType value of the operator's input tensor.

    Returns:
        A DType value drawn at random from the incorrect candidates.
    """
    incorrect_types = None
    if op_name == "fully_connected" or op_name == "matmul":
        if input_dtype == DType.INT8:
            # INT32 is deliberately absent (presumably the valid output type).
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype == DType.INT16:
            # INT48 is deliberately absent (presumably the valid output type).
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
            )
        elif input_dtype in (DType.FP32, DType.FP16, DType.BF16):
            # Only integer types are listed as incorrect for float inputs.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )

    if incorrect_types is None:
        # No op-specific list applies: assume all types but the input type
        # are incorrect. Without this fallback, incorrect_types would be
        # unbound here and raise UnboundLocalError.
        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
    return rng.choice(a=incorrect_types)
James Ward24dbc422022-10-19 12:20:31 +0100150
151
def get_rank_mismatch_shape(rng, output_shape):
    """Return a copy of *output_shape* with extra trailing size-1 dimensions.

    Extends the rank by an arbitrary amount (1, 2 or 3, drawn with
    ``rng.choice``) while keeping the total element count the same, since
    every appended dimension has size 1. The caller's ``output_shape`` is
    left unmodified.

    Args:
        rng: Random generator providing ``choice`` (e.g. a NumPy Generator).
        output_shape: Sequence of dimension sizes (list or tuple).

    Returns:
        A new list: the original dimensions followed by the extra 1s.
    """
    rank_modifier = rng.choice([1, 2, 3])
    # Build a new list instead of using ``+=``: that mutated the caller's
    # list in place, and raised TypeError for tuple shapes.
    return list(output_shape) + [1] * rank_modifier
161
162
def float32_is_valid_bfloat16(f):
    """Return True if float value is valid bfloat16.

    The value qualifies when the lower 16 bits of its fp32 encoding are all
    zero, i.e. it survives clearing those bits unchanged.
    """
    # The bitstring is always 32 characters of '0'/'1'; inspect the low half.
    low_half = get_float32_bitstring(f)[16:]
    return "1" not in low_half
167
168
def get_float32_bitstring(f):
    """Return the 32-bit single-precision encoding of *f* as a bit string.

    The result has exactly 32 '0'/'1' characters, most significant bit
    first (big-endian order).
    """
    packed = struct.pack(">f", f)
    as_uint = int.from_bytes(packed, byteorder="big")
    return format(as_uint, "032b")
173
174
def float32_to_bfloat16(f):
    """Turn an fp32 value into bfloat16 by truncating its low 16 bits.

    The least significant 16 bits of the fp32 encoding of ``f`` are cleared,
    and the resulting bf16-valid value is returned as a Python float.
    Clearing mantissa bits reduces the magnitude, so finite values are
    rounded toward zero: negative inputs move up, not down (this is not a
    floor).
    NOTE: a NaN whose payload lies entirely in the low 16 bits becomes an
    infinity.

    Both the pack and unpack steps use an explicit big-endian IEEE layout,
    so the result does not depend on the host's native byte order.
    """
    f32_bits = struct.unpack(">L", struct.pack(">f", f))[0]
    # Keep the top 16 bits (sign, exponent and leading mantissa bits).
    bf16_bits = f32_bits & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">L", bf16_bits))[0]
190
191
# Element-wise float32_to_bfloat16 over array-like input, with float32 output
# dtype. np.vectorize is a convenience wrapper that is essentially a Python
# for loop, not a performance optimization.
vect_f32_to_bf16 = np.vectorize(float32_to_bfloat16, otypes=(np.float32,))