blob: 104d9bb40b0d2723c52fe45e3aa4a13d9e799a9a [file] [log] [blame]
# Copyright (c) 2021-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
from tosa.DType import DType

# Maximum dimension size for output and inputs for RESIZE
MAX_RESIZE_DIMENSION = 16384

# Per-DType attributes:
#   "str":   short type name (e.g. "i8"); presumably used when building
#            generated test names - confirm against callers.
#   "width": size in bits, matching the numeric suffix of the type name
#            (BOOL is recorded as 1 bit).
DTYPE_ATTRIBUTES = {
    DType.BOOL: {"str": "b", "width": 1},
    DType.INT4: {"str": "i4", "width": 4},
    DType.INT8: {"str": "i8", "width": 8},
    DType.UINT8: {"str": "u8", "width": 8},
    DType.INT16: {"str": "i16", "width": 16},
    DType.UINT16: {"str": "u16", "width": 16},
    DType.INT32: {"str": "i32", "width": 32},
    DType.INT48: {"str": "i48", "width": 48},
    DType.FP16: {"str": "f16", "width": 16},
    DType.FP32: {"str": "f32", "width": 32},
}


def valueToName(item, value):
    """Get the name of an attribute with the given value.

    This convenience function is needed to print meaningful names for
    the values of the tosa.Op.Op and tosa.DType.DType classes.
    This would not be necessary if they were subclasses of Enum, or
    IntEnum, which, sadly, they are not.

    Python's own dunder attributes (names starting with "__") are skipped,
    like in allDTypes, so that e.g. a value of None can never be reported
    as "__doc__".

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)  # returns 'INT8'

    Returns:
        The name of the first attribute, in the alphabetical order returned
        by dir(), whose value compares equal to value.

    Raises:
        ValueError if the value is not found
    """
    for attr in dir(item):
        # Dunders are Python machinery, never meaningful value names.
        if attr.startswith("__"):
            continue
        if getattr(item, attr) == value:
            return attr
    raise ValueError(f"value ({value}) not found")


def allDTypes(*, excludes=None):
    """Get a set of all DType values, optionally excluding some values.

    This convenience function is needed to provide a sequence of DType values.
    This would be much easier if DType was a subclass of Enum, or IntEnum,
    as we could then iterate over the values directly, instead of using
    dir() to find the attributes and then check if they are what we want.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    # Materialize once: a one-shot iterable (e.g. a generator) would
    # otherwise be consumed by the first membership test below.
    excludes = tuple(excludes) if excludes else ()
    dtypes = set()
    for name in dir(DType):
        # Skip Python's own dunder attributes before even looking them up.
        if name.startswith("__"):
            continue
        # Fetch each attribute once; callables (methods) are not dtype values.
        value = getattr(DType, name)
        if callable(value) or value in excludes:
            continue
        dtypes.add(value)
    return dtypes


def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    Excludes uncommon types (DType.UNKNOWN, DType.UINT16, DType.UINT8) in
    addition to the excludes specified by the caller, as the serializer lib
    does not support them.
    If you wish to include 'UNKNOWN', 'UINT8' or 'UINT16' use allDTypes
    instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    # Always-omitted types, merged with whatever the caller asked to exclude.
    omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16}
    omit.update(excludes if excludes else ())
    return allDTypes(excludes=omit)


def product(shape):
    """Return the product of all values in shape.

    Args:
        shape: iterable of numbers, e.g. a tensor shape

    Returns:
        The product of the values, or 1 when shape is empty.
    """
    value = 1
    for n in shape:
        value *= n
    return value


def get_accum_dtype_from_tgTypes(dtypes):
    """Get the accumulate data type from the test generator's defined types.

    The test generator's type list/tuple carries the accumulate type as its
    last element.

    Args:
        dtypes: list or tuple of DType values

    Returns:
        The last element of dtypes.
    """
    # Internal sanity check on the test generator's own data, not user input.
    assert isinstance(dtypes, (list, tuple))
    return dtypes[-1]


def get_wrong_output_type(op_name, rng, input_dtype):
    """Pick a deliberately wrong output DType for op_name with input_dtype.

    Only "fully_connected" and "matmul" are supported. A fixed tuple of
    candidate DTypes is selected by input_dtype (INT8, INT16, or FP32/FP16)
    and one entry is chosen with rng.choice.

    Args:
        op_name: operator name, "fully_connected" or "matmul"
        rng: random generator with a choice(a=...) method - assumed to be a
            numpy.random.Generator; confirm against callers
        input_dtype: DType of the operator's input

    Returns:
        One DType from the selected candidate tuple.

    Raises:
        ValueError if op_name or input_dtype is not supported.
    """
    if op_name not in ("fully_connected", "matmul"):
        raise ValueError(f"wrong output type not supported for op: {op_name}")

    # NOTE: keep tuple order unchanged - rng.choice picks by position, so
    # reordering would change which type a seeded run generates.
    if input_dtype == DType.INT8:
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT48,
            DType.FP32,
            DType.FP16,
        )
    elif input_dtype == DType.INT16:
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.FP32,
            DType.FP16,
        )
    elif input_dtype in (DType.FP32, DType.FP16):
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
        )
    else:
        # Previously this fell through to an UnboundLocalError.
        raise ValueError(
            f"wrong output type not supported for {op_name} "
            f"with input dtype: {input_dtype}"
        )
    return rng.choice(a=incorrect_types)