blob: 3ad642ad5ee8400f794f641fc82bedcb9b32a781 [file] [log] [blame]
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Defines the basic numeric type classes for tensors.
Tim Hall79d07d22020-04-27 18:20:16 +010018import enum
Dwight Lidman9b43f842020-12-08 17:56:44 +010019from typing import Any
Tim Hall79d07d22020-04-27 18:20:16 +010020
Diego Russoea6111a2020-04-14 18:41:58 +010021from .numeric_util import round_up_divide
22
Tim Hall79d07d22020-04-27 18:20:16 +010023
class BaseType(enum.Flag):
    """Bit flags describing the fundamental kind of a tensor element type.

    Each primitive flag occupies its own bit; the integer families are
    compositions of the ``Int`` bit with signedness and asymmetry bits.
    """

    # Modifier bits, combined with Int below
    Signed = 1 << 0
    Unsigned = 1 << 1
    Asymmetric = 1 << 2
    Int = 1 << 3

    # Integer families built from the bits above
    SignedInt = Signed | Int
    UnsignedInt = Unsigned | Int
    AsymmSInt = Signed | Asymmetric | Int
    AsymmUInt = Unsigned | Asymmetric | Int

    # Remaining stand-alone kinds, one bit each
    Float = 1 << 4
    BFloat = 1 << 5
    Bool = 1 << 6
    String = 1 << 7
    Resource = 1 << 8
    Variant = 1 << 9
    Complex = 1 << 10
Tim Hall79d07d22020-04-27 18:20:16 +010040
41
class DataType:
    """Defines a data type. Consists of a base type, and the number of bits used for this type"""

    __slots__ = "type", "bits"

    # Names of the predefined instances. They are only declared here (no value);
    # the actual objects are attached at module level once the class exists.
    int8: Any
    int16: Any
    int32: Any
    int64: Any
    uint8: Any
    uint16: Any
    uint32: Any
    uint64: Any
    quint4: Any
    quint8: Any
    quint12: Any
    quint16: Any
    quint32: Any
    qint4: Any
    qint8: Any
    qint12: Any
    qint16: Any
    qint32: Any
    float16: Any
    float32: Any
    float64: Any
    string: Any
    bool: Any
    resource: Any
    variant: Any
    complex64: Any
    complex128: Any

    def __init__(self, type_, bits):
        """Store the base type (a BaseType flag) and the width in bits."""
        self.type = type_
        self.bits = bits

    def __eq__(self, other):
        """Two data types are equal when both base type and bit width match.

        Returns NotImplemented for operands that are not DataType instances so
        that comparisons such as ``dtype == None`` evaluate to False instead of
        raising AttributeError.
        """
        if not isinstance(other, DataType):
            return NotImplemented
        return self.type == other.type and self.bits == other.bits

    def __hash__(self):
        """Hash on the same (type, bits) pair that __eq__ compares."""
        return hash((self.type, self.bits))

    def size_in_bytes(self):
        """Return the bit width converted to bytes via round_up_divide(bits, 8)."""
        return round_up_divide(self.bits, 8)

    def size_in_bits(self):
        """Return the width of this type in bits."""
        return self.bits

    def __str__(self):
        """Return the textual name of the type, e.g. the stem with the bit width filled in.

        Raises KeyError if the base type has no entry in DataType.stem_name.
        """
        stem, needs_format = DataType.stem_name[self.type]
        if not needs_format:
            return stem
        return stem % (self.bits,)

    __repr__ = __str__

    # Maps a base type to (name stem, whether the stem takes the bit width).
    stem_name = {
        BaseType.UnsignedInt: ("uint%s", True),
        BaseType.SignedInt: ("int%s", True),
        BaseType.AsymmUInt: ("quint%s", True),
        BaseType.AsymmSInt: ("qint%s", True),
        BaseType.Float: ("float%s", True),
        BaseType.BFloat: ("bfloat%s", True),
        BaseType.Bool: ("bool", False),
        BaseType.String: ("string", False),
        BaseType.Resource: ("resource", False),
        BaseType.Variant: ("variant", False),
        BaseType.Complex: ("complex%s", True),
    }
113
114
# generate the standard set of data types
# Families whose attribute name is "<stem><bits>"; order matches the declaration order on DataType.
for _base, _stem, _widths in (
    (BaseType.SignedInt, "int", (8, 16, 32, 64)),
    (BaseType.UnsignedInt, "uint", (8, 16, 32, 64)),
    (BaseType.AsymmUInt, "quint", (4, 8, 12, 16, 32)),
    (BaseType.AsymmSInt, "qint", (4, 8, 12, 16, 32)),
    (BaseType.Float, "float", (16, 32, 64)),
):
    for _bits in _widths:
        setattr(DataType, f"{_stem}{_bits}", DataType(_base, _bits))

# Types whose attribute name is not derived from a stem plus width.
for _name, _base, _bits in (
    ("string", BaseType.String, 64),
    ("bool", BaseType.Bool, 8),
    ("resource", BaseType.Resource, 8),
    ("variant", BaseType.Variant, 8),
    ("complex64", BaseType.Complex, 64),
    ("complex128", BaseType.Complex, 128),
):
    setattr(DataType, _name, DataType(_base, _bits))

# Do not leave the loop temporaries behind in the module namespace.
del _base, _stem, _widths, _bits, _name