blob: 6a689d05b8b27aa1240ae16963ab8415a94b378c [file] [log] [blame]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
3from tosa.DType import DType
4
# Upper limit on the size of any dimension of the RESIZE inputs and output.
MAX_RESIZE_DIMENSION = 16 * 1024
7
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008
def valueToName(item, value):
    """Get the name of an attribute with the given value.

    This convenience function is needed to print meaningful names for
    the values of the tosa.Op.Op and tosa.DType.DType classes.
    This would not be necessary if they were subclasses of Enum, or
    IntEnum, which, sadly, they are not.

    Python-internal "dunder" attributes (e.g. ``__doc__``, ``__module__``)
    are never considered, matching the filtering done by allDTypes().

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8)  # returns 'INT8'
        name = valueToName(DType, 4)           # returns 'INT8'

    Returns:
        The name of the first attribute found with a matching value, in
        ``dir()`` order (alphabetical for ordinary classes).

    Raises:
        ValueError if the value is not found
    """
    for attr in dir(item):
        # Skip Python-internal attributes: e.g. a class without a docstring
        # has __doc__ == None, which would otherwise "match" value=None.
        if attr.startswith("__"):
            continue
        if getattr(item, attr) == value:
            return attr
    raise ValueError(f"value ({value}) not found")
36
37
def allDTypes(*, excludes=None):
    """Get a set of all DType values, optionally excluding some values.

    This convenience function is needed to provide a sequence of DType values.
    This would be much easier if DType was a subclass of Enum, or IntEnum,
    as we could then iterate over the values directly, instead of using
    dir() to find the attributes and then check if they are what we want.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL]).
            Any iterable is accepted, including one-shot iterators such as
            generators; it is read exactly once.

    Returns:
        A set of DType values
    """
    # Materialise the exclusions once. Testing membership directly against a
    # generator would consume it across the tests and silently drop entries.
    excluded = set(excludes) if excludes else set()
    values = set()
    for name in dir(DType):
        # Ignore Python-internal attributes of the class.
        if name.startswith("__"):
            continue
        attr_value = getattr(DType, name)
        # Ignore methods/helpers; only the plain constants are DType values.
        if callable(attr_value) or attr_value in excluded:
            continue
        values.add(attr_value)
    return values
60
61
def usableDTypes(*, excludes=None):
    """Return the set of DType values that the serializer lib can handle.

    DType.UNKNOWN, DType.UINT8 and DType.UINT16 are always left out because
    the serializer lib does not support them. Anything listed in
    ``excludes`` is left out as well. Call allDTypes instead if 'UNKNOWN',
    'UINT8' or 'UINT16' are needed.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    unsupported = {DType.UNKNOWN, DType.UINT8, DType.UINT16}
    return allDTypes(excludes=unsupported.union(excludes or ()))
80
81
def product(shape):
    """Multiply together every element of ``shape``.

    Args:
        shape: iterable of numbers, e.g. a list of dimension sizes

    Returns:
        The product of all elements; 1 when ``shape`` is empty.
    """
    result = 1
    for dim in shape:
        result *= dim
    return result