blob: a4b7b5376747b75811aab915b9d54b749b6471e2 [file] [log] [blame]
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Defines the basic numeric type classes for tensors.
Tim Hall79d07d22020-04-27 18:20:16 +010018import enum
19
Diego Russoea6111a2020-04-14 18:41:58 +010020from .numeric_util import round_up_divide
21
Tim Hall79d07d22020-04-27 18:20:16 +010022
class BaseType(enum.Flag):
    """Bit flags describing the fundamental kind of a tensor element type.

    The first four members are single property bits. The integer variants are
    composites of the ``Int`` bit with a signedness bit and, for the
    ``Asymm*`` variants, the ``Asymmetric`` bit. Every other kind occupies a
    bit of its own.
    """

    # Primitive property bits.
    Signed = 1 << 0
    Unsigned = 1 << 1
    Asymmetric = 1 << 2
    Int = 1 << 3

    # Integer variants, built by OR-ing the property bits above.
    SignedInt = Signed | Int
    UnsignedInt = Unsigned | Int
    AsymmSInt = Signed | Asymmetric | Int
    AsymmUInt = Unsigned | Asymmetric | Int

    # Remaining element kinds, one dedicated bit each.
    Float = 1 << 4
    BFloat = 1 << 5
    Bool = 1 << 6
    String = 1 << 7
    Resource = 1 << 8
    Variant = 1 << 9
    Complex = 1 << 10
Tim Hall79d07d22020-04-27 18:20:16 +010039
40
class DataType:
    """Defines a data type. Consists of a base type, and the number of bits used for this type.

    Instances behave as value objects: two DataTypes compare equal (and hash
    equally) when both their base type and their bit width match.

    Attributes:
        type: the base type. This is expected to be a ``BaseType`` member; the
            keys of ``stem_name`` are the base types that have a printable name.
        bits: the number of bits used for one element of this type.
    """

    __slots__ = "type", "bits"

    def __init__(self, type_, bits):
        self.type = type_
        self.bits = bits

    def __eq__(self, other):
        # Only DataType instances are comparable. Returning NotImplemented,
        # rather than reading other.type / other.bits unconditionally, lets
        # Python fall back to its default handling. So e.g. `dtype == None`
        # is False instead of raising AttributeError.
        if not isinstance(other, DataType):
            return NotImplemented
        return self.type == other.type and self.bits == other.bits

    def __hash__(self):
        # Must agree with __eq__: equal (type, bits) pairs hash identically.
        return hash((self.type, self.bits))

    def size_in_bytes(self):
        """Return the storage size of this type in bytes.

        Delegates to ``round_up_divide(self.bits, 8)``. That is presumably a
        ceiling division, so a 4-bit type would report 1 byte; confirm against
        numeric_util.
        """
        return round_up_divide(self.bits, 8)

    def size_in_bits(self):
        """Return the number of bits used for this type."""
        return self.bits

    def __str__(self):
        entry = DataType.stem_name.get(self.type)
        if entry is None:
            # No registered name for this base type (e.g. a bare flag such as
            # BaseType.Int). __repr__ shares this method, and a raising repr
            # breaks logging and debugging, so fall back to a generic form
            # instead of raising KeyError.
            return "DataType(%s, %s)" % (self.type, self.bits)
        stem, needs_format = entry
        if needs_format:
            # Sized stems such as "int%s" embed the bit width, e.g. "int8".
            return stem % (self.bits,)
        return stem

    __repr__ = __str__

    # Maps each nameable base type to (name stem, whether the stem is a
    # %-format string that takes the bit width).
    stem_name = {
        BaseType.UnsignedInt: ("uint%s", True),
        BaseType.SignedInt: ("int%s", True),
        BaseType.AsymmUInt: ("quint%s", True),
        BaseType.AsymmSInt: ("qint%s", True),
        BaseType.Float: ("float%s", True),
        BaseType.BFloat: ("bfloat%s", True),
        BaseType.Bool: ("bool", False),
        BaseType.String: ("string", False),
        BaseType.Resource: ("resource", False),
        BaseType.Variant: ("variant", False),
        BaseType.Complex: ("complex%s", True),
    }
84
85
# Generate the standard set of data types and attach each one to DataType as a
# class attribute (DataType.int8, DataType.float32, ...).
# Each row is (attribute name, base type, bit width).
for _name, _base_type, _bits in (
    ("int8", BaseType.SignedInt, 8),
    ("int16", BaseType.SignedInt, 16),
    ("int32", BaseType.SignedInt, 32),
    ("int64", BaseType.SignedInt, 64),
    ("uint8", BaseType.UnsignedInt, 8),
    ("uint16", BaseType.UnsignedInt, 16),
    ("uint32", BaseType.UnsignedInt, 32),
    ("uint64", BaseType.UnsignedInt, 64),
    ("quint4", BaseType.AsymmUInt, 4),
    ("quint8", BaseType.AsymmUInt, 8),
    ("quint12", BaseType.AsymmUInt, 12),
    ("quint16", BaseType.AsymmUInt, 16),
    ("quint32", BaseType.AsymmUInt, 32),
    ("qint4", BaseType.AsymmSInt, 4),
    ("qint8", BaseType.AsymmSInt, 8),
    ("qint12", BaseType.AsymmSInt, 12),
    ("qint16", BaseType.AsymmSInt, 16),
    ("qint32", BaseType.AsymmSInt, 32),
    ("float16", BaseType.Float, 16),
    ("float32", BaseType.Float, 32),
    ("float64", BaseType.Float, 64),
    ("string", BaseType.String, 64),
    ("bool", BaseType.Bool, 8),
    ("resource", BaseType.Resource, 8),
    ("variant", BaseType.Variant, 8),
    ("complex64", BaseType.Complex, 64),
    ("complex128", BaseType.Complex, 128),
):
    setattr(DataType, _name, DataType(_base_type, _bits))
# Remove the loop variables so the module namespace is the same as with
# one explicit assignment per type.
del _name, _base_type, _bits