blob: 7fa31e7d0b30572bd727dbeb5189334cd64f18a0 [file] [log] [blame]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
3from tosa.DType import DType
4
# Maximum dimension size for output and inputs for RESIZE (2**14 == 16384).
MAX_RESIZE_DIMENSION = 2**14
7
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008
def valueToName(item, value):
    """Get the name of an attribute with the given value.

    This convenience function is needed to print meaningful names for
    the values of the tosa.Op.Op and tosa.DType.DType classes.
    This would not be necessary if they were subclasses of Enum, or
    IntEnum, which, sadly, they are not.

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)  # returns 'INT8'

    Returns:
        The name of the first attribute found with a matching value, in the
        (alphabetically sorted) order returned by dir(). Dunder attributes
        (names starting with "__") are never considered, so Python's own
        class machinery (e.g. __doc__, which may be None) cannot be
        reported as a match.

    Raises:
        ValueError if the value is not found
    """
    for attr in dir(item):
        # Skip dunders for consistency with allDTypes(); otherwise a lookup
        # such as valueToName(SomeClass, None) would "find" '__doc__'.
        if attr.startswith("__"):
            continue
        if getattr(item, attr) == value:
            return attr
    raise ValueError(f"value ({value}) not found")
35 raise ValueError(f"value ({value}) not found")
36
37
def allDTypes(*, excludes=None):
    """Get a set of all DType values, optionally excluding some values.

    This convenience function is needed to provide a sequence of DType values.
    This would be much easier if DType was a subclass of Enum, or IntEnum,
    as we could then iterate over the values directly, instead of using
    dir() to find the attributes and then check if they are what we want.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL]).
            Any iterable is accepted, including one-shot iterators such as
            generators, because it is collected into a set before filtering.

    Returns:
        A set of DType values
    """
    # Materialize exactly once. Testing `x in excludes` repeatedly against a
    # generator/iterator would consume it, silently dropping exclusions for
    # every attribute checked after the first.
    excluded = set(excludes) if excludes else set()
    return {
        getattr(DType, t)
        for t in dir(DType)
        if not callable(getattr(DType, t))
        and not t.startswith("__")
        and getattr(DType, t) not in excluded
    }
59 }
60
61
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    DType.UNKNOWN, DType.UINT8 and DType.UINT16 are always left out, on top
    of whatever the caller asks to exclude, because the serializer lib does
    not support them. Use allDTypes instead if 'UNKNOWN', 'UINT8' or
    'UINT16' are wanted.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    unsupported = {DType.UNKNOWN, DType.UINT8, DType.UINT16}
    return allDTypes(excludes=unsupported | set(excludes or ()))
79 return allDTypes(excludes=omit)
80
81
def product(shape):
    """Return the product of all elements of ``shape``.

    An empty ``shape`` yields 1, the multiplicative identity.
    """
    total = 1
    for dim in shape:
        total *= dim
    return total
86 return value
James Ward8b390432022-08-12 20:48:56 +010087
88
def get_accum_dtype_from_tgTypes(dtypes):
    """Get the accumulate data-type from the test generator's defined types.

    If ``dtypes`` is a list or tuple, its last element is taken as the
    accumulate type. Any other value is returned unchanged.
    """
    if not isinstance(dtypes, (list, tuple)):
        return dtypes
    return dtypes[-1]
94 return dtypes
95
96
def get_wrong_output_type(op_name, rng, input_dtype):
    """Randomly pick an incorrect output DType for an op/input-dtype pair.

    Only "fully_connected" and "matmul" are supported. For those ops a fixed
    tuple of DType values is selected by ``input_dtype`` (INT8, INT16, or
    FLOAT/FP16), and one is drawn with ``rng.choice(a=...)``.

    Args:
        op_name: name of the operator under test.
        rng: random generator providing ``choice(a=...)``. Presumably a
            numpy.random.Generator; TODO confirm against callers.
        input_dtype: DType of the operator's input.

    Returns:
        The value returned by ``rng.choice`` for the selected tuple.

    Raises:
        ValueError: if ``op_name`` or ``input_dtype`` is not handled.
            Previously this case surfaced as an accidental
            UnboundLocalError on ``incorrect_types``.
    """
    if op_name != "fully_connected" and op_name != "matmul":
        raise ValueError(f"get_wrong_output_type: unsupported op ({op_name})")

    if input_dtype == DType.INT8:
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT48,
            DType.FLOAT,
            DType.FP16,
        )
    elif input_dtype == DType.INT16:
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.FLOAT,
            DType.FP16,
        )
    elif input_dtype == DType.FLOAT or input_dtype == DType.FP16:
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
        )
    else:
        raise ValueError(
            f"get_wrong_output_type: unsupported input dtype ({input_dtype}) "
            f"for op ({op_name})"
        )
    return rng.choice(a=incorrect_types)
125 return rng.choice(a=incorrect_types)