Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 1 | # Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. |
| 2 | # |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | # |
| 5 | # Licensed under the Apache License, Version 2.0 (the License); you may |
| 6 | # not use this file except in compliance with the License. |
| 7 | # You may obtain a copy of the License at |
| 8 | # |
| 9 | # www.apache.org/licenses/LICENSE-2.0 |
| 10 | # |
| 11 | # Unless required by applicable law or agreed to in writing, software |
| 12 | # distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| 13 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | # See the License for the specific language governing permissions and |
| 15 | # limitations under the License. |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 16 | # Description: |
| 17 | # Defines the basic numeric type classes for tensors. |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 18 | import enum |
Dwight Lidman | 9b43f84 | 2020-12-08 17:56:44 +0100 | [diff] [blame] | 19 | from typing import Any |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 20 | |
Diego Russo | ea6111a | 2020-04-14 18:41:58 +0100 | [diff] [blame] | 21 | from .numeric_util import round_up_divide |
| 22 | |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 23 | |
class BaseType(enum.Flag):
    """Bit flags for the fundamental category of a tensor element type.

    The first four members are single property bits. SignedInt, UnsignedInt,
    AsymmSInt and AsymmUInt are composites built by OR-ing those bits. Every
    remaining category occupies a bit of its own.
    """

    # Single property bits
    Signed = 1 << 0
    Unsigned = 1 << 1
    Asymmetric = 1 << 2
    Int = 1 << 3

    # Integer categories composed from the property bits above
    SignedInt = Signed | Int
    UnsignedInt = Unsigned | Int
    AsymmSInt = Signed | Asymmetric | Int
    AsymmUInt = Unsigned | Asymmetric | Int

    # Stand-alone categories, one bit each
    Float = 1 << 4
    BFloat = 1 << 5
    Bool = 1 << 6
    String = 1 << 7
    Resource = 1 << 8
    Variant = 1 << 9
    Complex = 1 << 10
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 40 | |
| 41 | |
class DataType:
    """Defines a data type: a BaseType category plus the number of bits used for one element.

    Instances compare equal and hash identically when both their ``type`` and
    ``bits`` match.
    """

    __slots__ = "type", "bits"

    # Annotation-only declarations of the standard instances. They create no
    # class attributes here; the actual values are assigned at module level
    # after the class body (e.g. DataType.int8 = DataType(...)).
    int4: Any
    int8: Any
    int16: Any
    int32: Any
    int48: Any
    int64: Any
    uint8: Any
    uint16: Any
    uint32: Any
    uint64: Any
    quint4: Any
    quint8: Any
    quint12: Any
    quint16: Any
    quint32: Any
    qint4: Any
    qint8: Any
    qint12: Any
    qint16: Any
    qint32: Any
    float16: Any
    float32: Any
    float64: Any
    string: Any
    bool: Any
    resource: Any
    variant: Any
    complex64: Any
    complex128: Any

    def __init__(self, type_, bits):
        """
        Args:
            type_: the BaseType category of this data type (used as the key
                into ``stem_name`` when formatting).
            bits: number of bits used by one element of this type.
        """
        self.type = type_
        self.bits = bits

    def __eq__(self, other):
        # For foreign objects (None, strings, ...) defer to Python's fallback
        # instead of raising AttributeError on ``other.type``. The fallback is
        # the reflected __eq__, then identity.
        if not isinstance(other, DataType):
            return NotImplemented
        return self.type == other.type and self.bits == other.bits

    def __hash__(self):
        # Must stay consistent with __eq__: equal (type, bits) pairs hash equally.
        return hash((self.type, self.bits))

    def size_in_bytes(self):
        """Return the byte size of one element: ``bits`` divided by 8 via round_up_divide.

        NOTE(review): assumes round_up_divide is a ceiling division, so that
        sub-byte types such as int4 report 1 byte. Confirm in numeric_util.
        """
        return round_up_divide(self.bits, 8)

    def size_in_bits(self):
        """Return the number of bits used by one element."""
        return self.bits

    def __str__(self):
        """Return the type's name built from ``stem_name``, e.g. "int8", "quint4", "bool".

        Raises:
            KeyError: if ``self.type`` has no entry in ``stem_name``.
        """
        stem, needs_format = DataType.stem_name[self.type]
        if not needs_format:
            return stem
        else:
            # Parametrised stems carry a "%s" placeholder for the bit width.
            return stem % (self.bits,)

    __repr__ = __str__

    # Maps each BaseType to (name stem, whether the stem needs the bit width).
    stem_name = {
        BaseType.UnsignedInt: ("uint%s", True),
        BaseType.SignedInt: ("int%s", True),
        BaseType.AsymmUInt: ("quint%s", True),
        BaseType.AsymmSInt: ("qint%s", True),
        BaseType.Float: ("float%s", True),
        BaseType.BFloat: ("bfloat%s", True),
        BaseType.Bool: ("bool", False),
        BaseType.String: ("string", False),
        BaseType.Resource: ("resource", False),
        BaseType.Variant: ("variant", False),
        BaseType.Complex: ("complex%s", True),
    }
| 115 | |
| 116 | |
# Generate the standard set of data types. Each (name, base type, bits) row
# below becomes a class attribute of DataType (e.g. DataType.int8), and the
# attributes are assigned in the order listed.
for _name, _base_type, _bits in (
    # signed integers
    ("int4", BaseType.SignedInt, 4),
    ("int8", BaseType.SignedInt, 8),
    ("int16", BaseType.SignedInt, 16),
    ("int32", BaseType.SignedInt, 32),
    ("int48", BaseType.SignedInt, 48),
    ("int64", BaseType.SignedInt, 64),
    # unsigned integers
    ("uint8", BaseType.UnsignedInt, 8),
    ("uint16", BaseType.UnsignedInt, 16),
    ("uint32", BaseType.UnsignedInt, 32),
    ("uint64", BaseType.UnsignedInt, 64),
    # asymmetric unsigned integers
    ("quint4", BaseType.AsymmUInt, 4),
    ("quint8", BaseType.AsymmUInt, 8),
    ("quint12", BaseType.AsymmUInt, 12),
    ("quint16", BaseType.AsymmUInt, 16),
    ("quint32", BaseType.AsymmUInt, 32),
    # asymmetric signed integers
    ("qint4", BaseType.AsymmSInt, 4),
    ("qint8", BaseType.AsymmSInt, 8),
    ("qint12", BaseType.AsymmSInt, 12),
    ("qint16", BaseType.AsymmSInt, 16),
    ("qint32", BaseType.AsymmSInt, 32),
    # floating point
    ("float16", BaseType.Float, 16),
    ("float32", BaseType.Float, 32),
    ("float64", BaseType.Float, 64),
    # non-numeric types
    ("string", BaseType.String, 64),
    ("bool", BaseType.Bool, 8),
    ("resource", BaseType.Resource, 8),
    ("variant", BaseType.Variant, 8),
    # complex
    ("complex64", BaseType.Complex, 64),
    ("complex128", BaseType.Complex, 128),
):
    setattr(DataType, _name, DataType(_base_type, _bits))

# Remove the loop variables so the module namespace matches explicit assignments.
del _name, _base_type, _bits