blob: a4ef31aecda8fb9f15b006aaa6d788540f2d7778 [file] [log] [blame]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
3from tosa.DType import DType
4
5
def valueToName(item, value):
    """Get the name of an attribute with the given value.

    This convenience function is needed to print meaningful names for
    the values of the tosa.Op.Op and tosa.DType.DType classes.
    This would not be necessary if they were subclasses of Enum, or
    IntEnum, which, sadly, they are not.

    Python-internal "dunder" attributes (names starting with "__", such as
    "__doc__" or "__module__") are never returned, in the same way that
    allDTypes() ignores them; otherwise e.g. a value of None could be
    reported as "__doc__" on a class without a docstring.

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)  # returns 'INT8'

    Returns:
        The name of the first attribute, in dir() order, whose value
        compares equal to the given value.

    Raises:
        ValueError if the value is not found
    """
    for attr in dir(item):
        # Skip Python-internal attributes; only user-defined names are
        # meaningful answers here.
        if attr.startswith("__"):
            continue
        if getattr(item, attr) == value:
            return attr
    raise ValueError(f"value ({value}) not found")
33
34
def allDTypes(*, excludes=None):
    """Get a set of all DType values, optionally excluding some values.

    This convenience function is needed to provide a sequence of DType values.
    This would be much easier if DType was a subclass of Enum, or IntEnum,
    as we could then iterate over the values directly, instead of using
    dir() to find the attributes and then check if they are what we want.

    Only non-callable attributes of DType whose names do not start with
    "__" are considered to be DType values.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL]).
            It is read exactly once, so one-shot iterables such as
            generators are handled correctly.

    Returns:
        A set of DType values
    """
    # Materialise the exclusions once: repeated ``in`` tests against a
    # generator would consume it and silently drop exclusions. An explicit
    # None check also avoids truth-testing array-like arguments.
    excluded = frozenset(excludes) if excludes is not None else frozenset()
    # Look each attribute up only once.
    candidates = (
        getattr(DType, name) for name in dir(DType) if not name.startswith("__")
    )
    return {v for v in candidates if not callable(v) and v not in excluded}
57
58
def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    The uncommon types DType.UNKNOWN, DType.UINT8 and DType.UINT16 are
    always left out, because the serializer lib does not support them;
    any values given in ``excludes`` are left out as well.
    Use allDTypes instead if 'UNKNOWN', 'UINT8' or 'UINT16' are needed.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    unsupported = {DType.UNKNOWN, DType.UINT8, DType.UINT16}
    if excludes:
        # Merge the caller's exclusions with the always-unsupported ones.
        unsupported.update(excludes)
    return allDTypes(excludes=unsupported)
77
78
def product(shape):
    """Return the product of the entries of ``shape``.

    For a tensor shape this is the total element count; an empty shape
    yields 1.

    Args:
        shape: iterable of numbers (e.g. dimension sizes)

    Returns:
        The product of all entries, starting from 1
    """
    total = 1
    for dim in shape:
        total *= dim
    return total