Rickard Bolin | bc6ee58 | 2022-11-04 08:24:29 +0000 | [diff] [blame] | 1 | # SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com> |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 2 | # |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | # |
| 5 | # Licensed under the Apache License, Version 2.0 (the License); you may |
| 6 | # not use this file except in compliance with the License. |
| 7 | # You may obtain a copy of the License at |
| 8 | # |
| 9 | # www.apache.org/licenses/LICENSE-2.0 |
| 10 | # |
| 11 | # Unless required by applicable law or agreed to in writing, software |
| 12 | # distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| 13 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | # See the License for the specific language governing permissions and |
| 15 | # limitations under the License. |
Rickard Bolin | bc6ee58 | 2022-11-04 08:24:29 +0000 | [diff] [blame] | 16 | # |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 17 | # Description: |
| 18 | # Defines the basic numeric type classes for tensors. |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 19 | import enum |
Dwight Lidman | 9b43f84 | 2020-12-08 17:56:44 +0100 | [diff] [blame] | 20 | from typing import Any |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 21 | |
James Peet | 7519d50 | 2021-07-19 16:47:58 +0100 | [diff] [blame] | 22 | import numpy as np |
| 23 | |
Diego Russo | ea6111a | 2020-04-14 18:41:58 +0100 | [diff] [blame] | 24 | from .numeric_util import round_up_divide |
| 25 | |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 26 | |
class BaseType(enum.Flag):
    """Bit flags describing the fundamental kind of a tensor element type.

    The low four bits are combinable modifiers (signedness, asymmetric
    quantisation, integer-ness); the integer composites below are built
    from them. The remaining members are standalone single-bit kinds.
    """

    # Combinable modifier bits.
    Signed = 1 << 0
    Unsigned = 1 << 1
    Asymmetric = 1 << 2
    Int = 1 << 3

    # Integer composites (plain and asymmetric, signed and unsigned).
    SignedInt = Signed | Int
    UnsignedInt = Unsigned | Int
    AsymmSInt = Signed | Asymmetric | Int
    AsymmUInt = Unsigned | Asymmetric | Int

    # Standalone kinds, one bit each.
    Float = 1 << 4
    BFloat = 1 << 5
    Bool = 1 << 6
    String = 1 << 7
    Resource = 1 << 8
    Variant = 1 << 9
    Complex = 1 << 10
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 43 | |
| 44 | |
class DataType:
    """Defines a data type. Consists of a base type, and the number of bits used for this type.

    Two DataType instances are equal (and hash equally) when both their base
    type and bit width match. The standard instances declared below are
    assigned as class attributes after the class definition.
    """

    __slots__ = "type", "bits"

    # Standard data types, populated as class attributes after the class body.
    int4: Any
    int8: Any
    int16: Any
    int32: Any
    int48: Any
    int64: Any
    uint8: Any
    uint16: Any
    uint32: Any
    uint64: Any
    quint4: Any
    quint8: Any
    quint12: Any
    quint16: Any
    quint32: Any
    qint4: Any
    qint8: Any
    qint12: Any
    qint16: Any
    qint32: Any
    float16: Any
    float32: Any
    float64: Any
    string: Any
    bool: Any
    resource: Any
    variant: Any
    complex64: Any
    complex128: Any

    def __init__(self, type_, bits):
        # type_: a BaseType member; bits: width of the type in bits.
        self.type = type_
        self.bits = bits

    def __eq__(self, other):
        """Return True when both the base type and the bit width match.

        For objects that are not DataType instances, NotImplemented is
        returned so Python falls back to its default comparison (``==``
        evaluates to False) instead of raising AttributeError.
        """
        if not isinstance(other, DataType):
            return NotImplemented
        return self.type == other.type and self.bits == other.bits

    def __hash__(self):
        # Hash the same (type, bits) pair that __eq__ compares, so equal
        # DataTypes hash equally.
        return hash((self.type, self.bits))

    def size_in_bytes(self):
        """Return the storage size in whole bytes: round_up_divide(bits, 8)."""
        return round_up_divide(self.bits, 8)

    def size_in_bits(self):
        """Return the width of the type in bits."""
        return self.bits

    def __str__(self):
        # stem_name maps the base type to a name stem and a flag telling
        # whether the bit width must be substituted into it (e.g. "int%s").
        stem, needs_format = DataType.stem_name[self.type]
        if not needs_format:
            return stem
        else:
            return stem % (self.bits,)

    __repr__ = __str__

    def as_numpy_type(self):
        """Return the numpy dtype for this data type.

        The dtype string is the numpy kind code for the base type followed by
        size_in_bytes(), e.g. int16 -> "i2". Sub-byte types therefore map to
        their rounded-up byte width (int4 -> "i1").

        Only plain signed/unsigned integers, floats and complex types are
        supported; any other base type fails the assertion below.
        NOTE(review): widths with no matching numpy type (e.g. int48 -> "i6")
        are presumably rejected by np.dtype itself — confirm.
        """
        codes = DataType._numpy_dtype_code
        assert self.type in codes, f"Failed to interpret {self} as a numpy dtype"
        return np.dtype(codes[self.type] + str(self.size_in_bytes()))

    # Base type -> (name stem, whether the stem takes the bit width).
    stem_name = {
        BaseType.UnsignedInt: ("uint%s", True),
        BaseType.SignedInt: ("int%s", True),
        BaseType.AsymmUInt: ("quint%s", True),
        BaseType.AsymmSInt: ("qint%s", True),
        BaseType.Float: ("float%s", True),
        BaseType.BFloat: ("bfloat%s", True),
        BaseType.Bool: ("bool", False),
        BaseType.String: ("string", False),
        BaseType.Resource: ("resource", False),
        BaseType.Variant: ("variant", False),
        BaseType.Complex: ("complex%s", True),
    }

    # Base type -> numpy array-protocol kind code, used by as_numpy_type().
    # Built once here rather than on every call.
    _numpy_dtype_code = {
        BaseType.UnsignedInt: "u",
        BaseType.SignedInt: "i",
        BaseType.Float: "f",
        BaseType.Complex: "c",
    }
| 128 | |
| 129 | |
# Generate the standard set of data types and attach each one to DataType
# as a class attribute. Each row is (attribute name, base type, width in bits);
# rows are processed in order, so instances are created in the listed order.
for _attr, _base, _width in (
    ("int4", BaseType.SignedInt, 4),
    ("int8", BaseType.SignedInt, 8),
    ("int16", BaseType.SignedInt, 16),
    ("int32", BaseType.SignedInt, 32),
    ("int48", BaseType.SignedInt, 48),
    ("int64", BaseType.SignedInt, 64),
    ("uint8", BaseType.UnsignedInt, 8),
    ("uint16", BaseType.UnsignedInt, 16),
    ("uint32", BaseType.UnsignedInt, 32),
    ("uint64", BaseType.UnsignedInt, 64),
    ("quint4", BaseType.AsymmUInt, 4),
    ("quint8", BaseType.AsymmUInt, 8),
    ("quint12", BaseType.AsymmUInt, 12),
    ("quint16", BaseType.AsymmUInt, 16),
    ("quint32", BaseType.AsymmUInt, 32),
    ("qint4", BaseType.AsymmSInt, 4),
    ("qint8", BaseType.AsymmSInt, 8),
    ("qint12", BaseType.AsymmSInt, 12),
    ("qint16", BaseType.AsymmSInt, 16),
    ("qint32", BaseType.AsymmSInt, 32),
    ("float16", BaseType.Float, 16),
    ("float32", BaseType.Float, 32),
    ("float64", BaseType.Float, 64),
    ("string", BaseType.String, 64),
    ("bool", BaseType.Bool, 8),
    ("resource", BaseType.Resource, 8),
    ("variant", BaseType.Variant, 8),
    ("complex64", BaseType.Complex, 64),
    ("complex128", BaseType.Complex, 128),
):
    setattr(DataType, _attr, DataType(_base, _width))

# Drop the loop temporaries so the module namespace is unchanged.
del _attr, _base, _width