blob: 829cef38d5c86c76641f6b1a6357b0073a804b5a [file] [log] [blame]
Rickard Bolinbc6ee582022-11-04 08:24:29 +00001# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
Tim Hall79d07d22020-04-27 18:20:16 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Rickard Bolinbc6ee582022-11-04 08:24:29 +000016#
Tim Hall79d07d22020-04-27 18:20:16 +010017# Description:
18# Defines the basic numeric type classes for tensors.
Tim Hall79d07d22020-04-27 18:20:16 +010019import enum
Dwight Lidman9b43f842020-12-08 17:56:44 +010020from typing import Any
Tim Hall79d07d22020-04-27 18:20:16 +010021
James Peet7519d502021-07-19 16:47:58 +010022import numpy as np
23
Diego Russoea6111a2020-04-14 18:41:58 +010024from .numeric_util import round_up_divide
25
Tim Hall79d07d22020-04-27 18:20:16 +010026
class BaseType(enum.Flag):
    """Bit flags describing the base kind of a tensor data type.

    The three lowest bits (Signed, Unsigned, Asymmetric) are modifiers that are
    combined with Int to form the integer kinds; every other member occupies a
    single bit of its own.
    """

    # Modifier bits.
    Signed = 1 << 0
    Unsigned = 1 << 1
    Asymmetric = 1 << 2

    # Integer kinds: Int plus signedness and, optionally, asymmetry.
    Int = 1 << 3
    SignedInt = Signed | Int
    UnsignedInt = Unsigned | Int
    AsymmSInt = Signed | Asymmetric | Int
    AsymmUInt = Unsigned | Asymmetric | Int

    # Remaining base kinds, one bit each.
    Float = 1 << 4
    BFloat = 1 << 5
    Bool = 1 << 6
    String = 1 << 7
    Resource = 1 << 8
    Variant = 1 << 9
    Complex = 1 << 10
Tim Hall79d07d22020-04-27 18:20:16 +010043
44
class DataType:
    """Defines a data type. Consists of a base type, and the number of bits used for this type.

    Two instances compare equal, and hash equally, when both ``type`` (a
    BaseType flag) and ``bits`` match. This makes them usable as dict keys and
    set members.
    """

    __slots__ = "type", "bits"

    # Standard instances. They are created and attached as class attributes
    # after the class definition; these annotations only declare them.
    int4: Any
    int8: Any
    int16: Any
    int32: Any
    int48: Any
    int64: Any
    uint8: Any
    uint16: Any
    uint32: Any
    uint64: Any
    quint4: Any
    quint8: Any
    quint12: Any
    quint16: Any
    quint32: Any
    qint4: Any
    qint8: Any
    qint12: Any
    qint16: Any
    qint32: Any
    float16: Any
    float32: Any
    float64: Any
    string: Any
    bool: Any
    resource: Any
    variant: Any
    complex64: Any
    complex128: Any

    def __init__(self, type_: BaseType, bits: int):
        """Create a data type from a BaseType flag and a width in bits."""
        self.type = type_
        self.bits = bits

    def __eq__(self, other: object):
        """Return True when both the base type and the bit width match.

        Comparing against anything that is not a DataType returns
        NotImplemented, so Python falls back to its default comparison. For
        example, ``dtype == None`` evaluates to False instead of raising
        AttributeError.
        """
        if not isinstance(other, DataType):
            return NotImplemented
        return self.type == other.type and self.bits == other.bits

    def __hash__(self):
        # Hashes the same fields that __eq__ compares, so equal data types hash equally.
        return hash((self.type, self.bits))

    def size_in_bytes(self) -> int:
        """Return the size in bytes, computed as round_up_divide(bits, 8).

        round_up_divide presumably rounds up, so a partial byte (e.g. int4)
        counts as a whole byte.
        """
        return round_up_divide(self.bits, 8)

    def size_in_bits(self) -> int:
        """Return the width of this data type in bits."""
        return self.bits

    def __str__(self) -> str:
        """Return the data type's name, e.g. "int8", "quint12" or "bool".

        Raises:
            KeyError: if the base type has no entry in stem_name.
        """
        stem, needs_format = DataType.stem_name[self.type]
        if not needs_format:
            return stem
        return stem % (self.bits,)

    __repr__ = __str__

    def as_numpy_type(self) -> np.dtype:
        """Return the numpy dtype equivalent to this data type.

        Only the plain SignedInt, UnsignedInt, Float and Complex base types are
        supported; any other base type fails the assertion. The numpy item size
        is taken from size_in_bytes(), so sub-byte widths map to a whole-byte
        dtype.

        Raises:
            AssertionError: if the base type has no numpy equivalent. When
                assertions are disabled (-O), the lookup raises KeyError instead.
        """
        code = DataType._numpy_dtype_code
        assert self.type in code, f"Failed to interpret {self} as a numpy dtype"
        # NOTE(review): widths numpy has no dtype for (e.g. int48 -> "i6") presumably
        # make np.dtype() raise rather than hitting the assertion above — confirm.
        return np.dtype(code[self.type] + str(self.size_in_bytes()))

    # Maps each base type to (name stem, whether the stem is %-formatted with the bit width).
    stem_name = {
        BaseType.UnsignedInt: ("uint%s", True),
        BaseType.SignedInt: ("int%s", True),
        BaseType.AsymmUInt: ("quint%s", True),
        BaseType.AsymmSInt: ("qint%s", True),
        BaseType.Float: ("float%s", True),
        BaseType.BFloat: ("bfloat%s", True),
        BaseType.Bool: ("bool", False),
        BaseType.String: ("string", False),
        BaseType.Resource: ("resource", False),
        BaseType.Variant: ("variant", False),
        BaseType.Complex: ("complex%s", True),
    }

    # numpy array-protocol kind codes for the base types as_numpy_type supports.
    _numpy_dtype_code = {
        BaseType.UnsignedInt: "u",
        BaseType.SignedInt: "i",
        BaseType.Float: "f",
        BaseType.Complex: "c",
    }
128
129
# Generate the standard set of data types and attach them to DataType as
# class attributes, e.g. DataType.int8 or DataType.float32.

# Signed integers.
DataType.int4 = DataType(type_=BaseType.SignedInt, bits=4)
DataType.int8 = DataType(type_=BaseType.SignedInt, bits=8)
DataType.int16 = DataType(type_=BaseType.SignedInt, bits=16)
DataType.int32 = DataType(type_=BaseType.SignedInt, bits=32)
DataType.int48 = DataType(type_=BaseType.SignedInt, bits=48)
DataType.int64 = DataType(type_=BaseType.SignedInt, bits=64)

# Unsigned integers.
DataType.uint8 = DataType(type_=BaseType.UnsignedInt, bits=8)
DataType.uint16 = DataType(type_=BaseType.UnsignedInt, bits=16)
DataType.uint32 = DataType(type_=BaseType.UnsignedInt, bits=32)
DataType.uint64 = DataType(type_=BaseType.UnsignedInt, bits=64)

# Asymmetric unsigned integers (the "quint" family).
DataType.quint4 = DataType(type_=BaseType.AsymmUInt, bits=4)
DataType.quint8 = DataType(type_=BaseType.AsymmUInt, bits=8)
DataType.quint12 = DataType(type_=BaseType.AsymmUInt, bits=12)
DataType.quint16 = DataType(type_=BaseType.AsymmUInt, bits=16)
DataType.quint32 = DataType(type_=BaseType.AsymmUInt, bits=32)

# Asymmetric signed integers (the "qint" family).
DataType.qint4 = DataType(type_=BaseType.AsymmSInt, bits=4)
DataType.qint8 = DataType(type_=BaseType.AsymmSInt, bits=8)
DataType.qint12 = DataType(type_=BaseType.AsymmSInt, bits=12)
DataType.qint16 = DataType(type_=BaseType.AsymmSInt, bits=16)
DataType.qint32 = DataType(type_=BaseType.AsymmSInt, bits=32)

# Floating point.
DataType.float16 = DataType(type_=BaseType.Float, bits=16)
DataType.float32 = DataType(type_=BaseType.Float, bits=32)
DataType.float64 = DataType(type_=BaseType.Float, bits=64)

# Non-numeric types, each with a fixed bit width.
DataType.string = DataType(type_=BaseType.String, bits=64)
DataType.bool = DataType(type_=BaseType.Bool, bits=8)
DataType.resource = DataType(type_=BaseType.Resource, bits=8)
DataType.variant = DataType(type_=BaseType.Variant, bits=8)

# Complex numbers.
DataType.complex64 = DataType(type_=BaseType.Complex, bits=64)
DataType.complex128 = DataType(type_=BaseType.Complex, bits=128)