blob: ca115a276f6bf5e71a5dc90f2e72e187ee2316a6 [file] [log] [blame]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
3from tosa.DType import DType
4
5
def valueToName(item, value):
    """Get the name of an attribute with the given value.

    This convenience function is needed to print meaningful names for
    the values of the tosa.Op.Op and tosa.DType.DType classes.
    This would not be necessary if they were subclasses of Enum, or
    IntEnum, which, sadly, they are not.

    Dunder attributes (names starting with "__") are never considered,
    in the same way that allDTypes ignores them.

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)  # returns 'INT8'

    Returns:
        The name of the first matching attribute, in the (sorted) order
        returned by dir(item).

    Raises:
        ValueError if the value is not found
    """
    for attr in dir(item):
        # Skip Python internals such as __doc__ / __module__: they are not
        # members, and e.g. a class without a docstring has __doc__ == None,
        # which would otherwise be returned for a lookup of None.
        if attr.startswith("__"):
            continue
        if getattr(item, attr) == value:
            return attr
    raise ValueError(f"value ({value}) not found")
33
34
def allDTypes(*, excludes=None):
    """Get a set of all DType values, optionally excluding some values.

    This convenience function is needed to provide a sequence of DType values.
    This would be much easier if DType was a subclass of Enum, or IntEnum,
    as we could then iterate over the values directly, instead of using
    dir() to find the attributes and then check if they are what we want.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL]);
            any iterable is accepted, including one-shot iterators such as
            generators

    Returns:
        A set of DType values
    """
    # Materialize once: membership is tested for every attribute below, and a
    # one-shot iterator (e.g. a generator) would be exhausted by the first
    # few tests, silently letting later excluded values through.
    excluded = tuple(excludes) if excludes else ()
    # Non-dunder, non-callable class attributes are the DType members.
    members = (
        getattr(DType, name) for name in dir(DType) if not name.startswith("__")
    )
    return {m for m in members if not callable(m) and m not in excluded}
57
58
def usableDTypes(*, excludes=None):
    """Return the DType values that the serializer lib can handle.

    DType.UNKNOWN and DType.UINT8 are always left out, because the
    serializer lib does not support them. Any values passed in
    ``excludes`` are left out as well. Use allDTypes instead if you need
    'UNKNOWN' or 'UINT8'.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    unsupported = {DType.UNKNOWN, DType.UINT8, *(excludes or ())}
    return allDTypes(excludes=unsupported)
75
76
def product(shape):
    """Multiply together every entry of ``shape``.

    Args:
        shape: iterable of numbers, e.g. a tensor shape such as [2, 3, 4]

    Returns:
        The product of all entries; 1 when ``shape`` is empty.
    """
    total = 1
    for dim in shape:
        total *= dim
    return total